{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}

module Network.DNS.Transport (
    Resolver(..)
  , resolve
  ) where

import Control.Concurrent.Async (async, cancel, waitAny, waitAnyCancel)
import Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import qualified Data.List.NonEmpty as NE
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)

import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types.Internal
import Network.DNS.Types.Resolver

-- | Check response for a matching identifier and question.  If we ever do
-- pipelined TCP, we'll need to handle out of order responses.  See:
-- https://tools.ietf.org/html/rfc7766#section-7
--
-- This is the boolean form of 'checkRespM': it is 'True' exactly when
-- 'checkRespM' reports no error for the response.
checkResp :: Question -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = case checkRespM q seqno resp of
    Nothing -> True
    Just _  -> False

-- | Validate a response against the query that produced it, returning the
-- first problem found, if any:
--
-- * a header identifier other than @seqno@ yields 'SequenceNumberMismatch';
-- * otherwise, a question section other than exactly @[q]@ yields
--   'QuestionMismatch'.
--
-- When the response 'RCODE' is 'FormatErr', the server did not understand our
-- query packet, and so is not expected to return a matching question; an
-- empty question section is therefore accepted in that case.
checkRespM :: Question -> Identifier -> DNSMessage -> Maybe DNSError
checkRespM q seqno resp
  | identifier hdr /= seqno       = Just SequenceNumberMismatch
  | isFormatErr && null questions = Nothing
  | questions /= [q]              = Just QuestionMismatch
  | otherwise                     = Nothing
  where
    hdr         = header resp
    isFormatErr = rcode (flags hdr) == FormatErr
    questions   = question resp

----------------------------------------------------------------

-- | Internal signal thrown by 'udpLookup' when an answer comes back
-- truncated; 'udpTcpLookup' catches it and repeats the lookup over TCP.
data TCPFallback = TCPFallback
    deriving Typeable

instance Show TCPFallback where
    showsPrec _ TCPFallback = showString "TCPFallback"

instance Exception TCPFallback

-- | Final stage shared by every resolver below: query controls plus the
-- function used to read a 'DNSMessage' from a socket, producing either a
-- 'DNSError' or the answer.
type Rslv0 =  QueryControls
           -> (Socket -> IO DNSMessage)
           -> IO (Either DNSError DNSMessage)

-- | A resolver for one question, given a timeout and a retry count.
type Rslv1 =  Question
           -> Int            -- timeout (passed to 'timeout')
           -> Int            -- retry count
           -> Rslv0

-- | A TCP lookup of one question against one server.
type TcpRslv =  AddrInfo
             -> Question
             -> Int          -- timeout (passed to 'timeout')
             -> QueryControls
             -> IO DNSMessage

-- | A UDP lookup: a retry count and a receive function in front of the
-- same arguments a 'TcpRslv' takes.
type UdpRslv =  Int          -- retry count
             -> (Socket -> IO DNSMessage)
             -> TcpRslv

-- | Look up @dom@ with record type @typ@ using the nameservers of the
-- given 'Resolver'.
--
-- Illegal domain names (see 'isIllegal') are rejected with 'IllegalDomain'
-- and zone transfers with 'InvalidAXFRLookup' before any packet is sent.
-- A missing trailing dot is appended to @dom@.  With a single nameserver
-- it is queried directly; otherwise all nameservers are queried either
-- concurrently or one after another, according to 'resolvConcurrent'.
--
-- In the lookup loop, we try UDP until we get a response.  If the response
-- is truncated, we try TCP once, with no further UDP retries.
--
-- For now, we optimize for low latency high-availability caches
-- (e.g.  running on a loopback interface), where TCP is cheap
-- enough.  We could attempt to complete the TCP lookup within the
-- original time budget of the truncated UDP query, by wrapping both
-- within a single 'timeout' thereby staying within the original
-- time budget, but it seems saner to give TCP a full opportunity to
-- return results.  TCP latency after a truncated UDP reply will be
-- atypical.
--
-- Future improvements might also include support for TCP on the
-- initial query.
--
-- This function merges the query flag overrides from the resolver
-- configuration with any additional overrides from the caller: the
-- caller's @qctls@ are combined with '<>' with 'resolvQueryControls'.
resolve :: Resolver -> Domain -> TYPE -> Rslv0
resolve rlv dom typ qctls rcv
  | isIllegal dom = return $ Left IllegalDomain
  | typ == AXFR   = return $ Left InvalidAXFRLookup
    -- Exactly one nameserver: query it directly with the first generator.
    -- Matching on the NonEmpty list avoids the partial list 'head'.
  | ns NE.:| [] <- nameservers seed
                  = resolveOne ns (NE.head (genIds rlv)) q tm retry ctls rcv
  | concurrent    = resolveConcurrent nss gens q tm retry ctls rcv
  | otherwise     = resolveSequential nss gens q tm retry ctls rcv
  where
    -- Queries are always made for the fully qualified name.
    q | "." `BS.isSuffixOf` dom = Question dom typ
      | otherwise               = Question (dom <> ".") typ

    gens = NE.toList $ genIds rlv

    seed       = resolvseed rlv
    conf       = resolvconf seed
    nss        = NE.toList $ nameservers seed
    ctls       = qctls <> resolvQueryControls conf
    concurrent = resolvConcurrent conf
    tm         = resolvTimeout conf
    retry      = resolvRetry conf


-- | Query the nameservers one at a time, in order, each with its own
-- identifier generator.  The first result that is not a 'Left' is
-- returned; if every server fails, the last server's 'Left' is returned.
--
-- Servers and generators are paired with 'zip' (as 'resolveConcurrent'
-- does), so a length mismatch between the two lists ignores the extras
-- instead of calling 'error' part-way through a lookup.  Calling this with
-- no servers at all is still a programming error.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q tm retry ctls rcv = loop (zip nss gs)
  where
    lookupAt ai gen = resolveOne ai gen q tm retry ctls rcv

    loop :: [(AddrInfo, IO Identifier)] -> IO (Either DNSError DNSMessage)
    loop []               = error "resolveSequential:loop"
    loop [(ai, gen)]      = lookupAt ai gen
    loop ((ai, gen):rest) = do
        eres <- lookupAt ai gen
        case eres of
          Left _ -> loop rest
          res    -> return res

-- | Query all nameservers at the same time, one 'async' thread per
-- (server, identifier generator) pair, and return the first successful
-- ('Right') result.
--
-- A failure ('Left') from one server does not cancel the others: we keep
-- waiting on the remaining servers, and only when every one has failed is
-- the last failure returned.  This matches 'resolveSequential', which also
-- moves on to the next server on 'Left'.  Previously the first result to
-- arrive won even when it was an error, so a fast-failing server could
-- hide answers from healthy ones.
--
-- Threads still running are cancelled once a result is chosen, or if
-- waiting is interrupted by an exception.  An exception other than
-- 'DNSError' raised in a worker thread is rethrown here by 'waitAny'.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q tm retry ctls rcv = do
    asyncs <- mapM mkAsync $ zip nss gens
    firstSuccess asyncs `E.finally` mapM_ cancel asyncs
  where
    mkAsync (ai, gen) = async $ resolveOne ai gen q tm retry ctls rcv

    -- Wait for the next thread to finish; accept a 'Right' immediately,
    -- and accept a 'Left' only if no other thread is still pending.
    firstSuccess pending = do
        (done, res) <- waitAny pending
        let others = filter (/= done) pending
        case res of
          Left _ | not (null others) -> firstSuccess others
          _                          -> return res

-- | Look up @q@ against the single server @ai@ (UDP first, TCP if the UDP
-- answer is truncated), capturing any 'DNSError' as a 'Left' instead of
-- letting it propagate.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q tm retry ctls rcv = E.try udpThenTcp
  where
    udpThenTcp = udpTcpLookup gen retry rcv ai q tm ctls

----------------------------------------------------------------

-- | Run a UDP lookup and, if it signals 'TCPFallback' (truncated answer),
-- repeat the lookup over TCP.
--
-- UDP attempts must use the same ID and accept delayed answers,
-- but we use a fresh ID for each TCP lookup: the UDP identifier is drawn
-- from @gen@ once here, while 'tcpLookup' draws its own from @gen@.
udpTcpLookup :: IO Identifier -> UdpRslv
udpTcpLookup gen retry rcv ai q tm ctls = do
    ident <- gen
    E.catch (udpLookup ident retry rcv ai q tm ctls) overTcp
  where
    overTcp TCPFallback = tcpLookup gen ai q tm ctls

----------------------------------------------------------------

-- | Rethrow an 'IOError' raised while talking to @ai@ as a 'NetworkFailure'.
-- Before wrapping, the error is annotated with a location string made of
-- the protocol name, an at sign and the shown server address.  Never
-- returns normally.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai protoName ioe =
    throwIO . NetworkFailure $ annotateIOError ioe location Nothing Nothing
  where
    location = concat [protoName, "@", show (addrAddress ai)]

----------------------------------------------------------------

-- | Create a socket for @ai@ (family, type and protocol taken from the
-- 'AddrInfo') and connect it to the server address.
--
-- If 'connect' fails, the socket is closed before the exception
-- propagates.  The caller's 'bracket' only releases a socket that was
-- successfully returned, so without 'bracketOnError' a failed connect
-- leaked the descriptor.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    E.bracketOnError
        (socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai))
        close
        (\sock -> connect sock (addrAddress ai) >> return sock)

-- | Send the query to @ai@ over UDP and wait for a matching answer.
--
-- Every attempt reuses the single identifier @ident@, so a delayed answer
-- to an earlier attempt is still accepted.  Each attempt is bounded by @tm@
-- (passed to 'timeout'); after @retry@ timed-out attempts
-- 'RetryLimitExceeded' is thrown.  A zero or negative @retry@ throws
-- immediately.  Previously a negative count made the loop retry forever.
--
-- A truncated answer throws 'TCPFallback' so the caller can switch to TCP.
-- A 'FormatErr' answer without an EDNS section, when our controls did not
-- already disable EDNS, is re-sent with EDNS disabled; that resend does not
-- count against @retry@.  IO errors are rethrown as 'NetworkFailure'.
--
-- This throws DNSError or TCPFallback.
udpLookup :: Identifier -> UdpRslv
udpLookup ident retry rcv ai q tm ctls =
    E.handle (ioErrorToDNSError ai "udp") $
      bracket (udpOpen ai) close (loop (encodeQuestion ident q ctls) ctls 0)
  where
    loop :: Domain -> QueryControls -> Int -> Socket -> IO DNSMessage
    loop qry lctls cnt sock
      -- '>=' rather than '==' so a negative retry count cannot loop forever.
      | cnt >= retry = E.throwIO RetryLimitExceeded
      | otherwise    = do
          mres <- timeout tm (send sock qry >> getAns sock)
          case mres of
              Nothing  -> loop qry lctls (cnt + 1) sock
              Just res -> do
                  let fl = flags $ header res
                      cs = ednsEnabled FlagClear <> lctls
                  if trunCation fl then E.throwIO TCPFallback
                  else if rcode fl == FormatErr
                       && ednsHeader res == NoEDNS
                       && cs /= lctls
                  then loop (encodeQuestion ident q cs) cs cnt sock
                  else return res

    -- Closed UDP ports are occasionally re-used for a new query, with
    -- the nameserver returning an unexpected answer to the wrong socket.
    -- Such answers should be simply dropped, with the client continuing
    -- to wait for the right answer, without resending the question.
    -- Note, this eliminates sequence mismatch as a UDP error condition,
    -- instead we'll time out if no matching answer arrives.
    getAns :: Socket -> IO DNSMessage
    getAns sock = do
        resp <- rcv sock
        if checkResp q ident resp
        then return resp
        else getAns sock

----------------------------------------------------------------

-- | Create an unconnected TCP (stream) socket whose address family matches
-- @peer@.  Only IPv4 and IPv6 addresses are supported; any other kind of
-- address throws 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen peer = maybe (E.throwIO ServerFailure) streamSocket (familyOf peer)
  where
    streamSocket fam = socket fam Stream defaultProtocol

    familyOf SockAddrInet{}  = Just AF_INET
    familyOf SockAddrInet6{} = Just AF_INET6
    familyOf _               = Nothing

-- | Perform a DNS query over TCP, if we were successful in creating
-- the TCP socket.
--
-- Each attempt opens a fresh connection to @ai@ and draws a fresh
-- identifier from @gen@. Connecting, sending and receiving are bounded
-- together by @tm@, and 'TimeoutExpired' is thrown on expiry. The answer
-- must pass 'checkRespM'. If the server answers 'FormatErr' without an
-- EDNS section, and our controls had not already disabled EDNS, the query
-- is made once more with EDNS disabled.  IO errors are rethrown as
-- 'NetworkFailure'.
--
-- This throws DNSError only.
tcpLookup :: IO Identifier -> TcpRslv
tcpLookup gen ai q tm ctls = E.handle (ioErrorToDNSError ai "tcp") $ do
    res <- overConnection ctls
    let hdrFlags    = flags (header res)
        withoutEdns = ednsEnabled FlagClear <> ctls
        -- If we first tried with EDNS, retry without on FormatErr.
        retryPlain  = rcode hdrFlags == FormatErr
                   && ednsHeader res == NoEDNS
                   && withoutEdns /= ctls
    if retryPlain
      then overConnection withoutEdns
      else return res
  where
    peer = addrAddress ai

    overConnection cs = bracket (tcpOpen peer) close (exchange cs)

    exchange cs vc = do
        ident <- gen
        let qry = encodeQuestion ident q cs
        mres <- timeout tm (connect vc peer >> sendVC vc qry >> receiveVC vc)
        case mres of
            Nothing  -> E.throwIO TimeoutExpired
            Just res -> maybe (return res) E.throwIO (checkRespM q ident res)

----------------------------------------------------------------

-- | Whether a domain name is empty or too long.  The limit is 253
-- characters, or 254 when the name is written with a trailing dot.
badLength :: Domain -> Bool
badLength dom
  | BS.null dom = True
  | otherwise   = BS.length dom > maxLen
  where
    maxLen
      | BS.last dom == '.' = 254
      | otherwise          = 253

-- | Whether a domain name must be rejected before any query is sent.
--
-- A name is illegal when any of the following holds:
--
-- * it is empty or too long ('badLength');
-- * it contains no dot;
-- * it contains a @':'@ or a @'/'@;
-- * one of its labels is empty or longer than 63 characters.
--
-- Per RFC 1035 every label is 1 to 63 octets and only the root label is
-- empty.  Names such as @"a..b"@, @".com"@ or @"com.."@ used to be
-- accepted; they are now rejected.  The single trailing dot of a fully
-- qualified name is not an empty label, and the root name @"."@ is still
-- accepted.
isIllegal :: Domain -> Bool
isIllegal dom =
       badLength dom
    || '.' `BS.notElem` dom
    || ':' `BS.elem` dom
    || '/' `BS.elem` dom
    || any badLabel labels
  where
    -- Labels of the name, ignoring one trailing dot (so "." has none).
    labels
      | "." `BS.isSuffixOf` dom = BS.split '.' (BS.init dom)
      | otherwise               = BS.split '.' dom
    badLabel l = BS.null l || BS.length l > 63