{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Network.DNS.Transport (
Resolver(..)
, resolve
) where
import Control.Exception as E
import qualified Data.List.NonEmpty as NE
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)

import Control.Concurrent.Async (Async, async, waitAny, waitAnyCancel, withAsync)
import qualified Data.ByteString.Char8 as BS
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)

import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types.Internal
import Network.DNS.Types.Resolver
-- | 'True' exactly when 'checkRespM' finds nothing wrong with the reply,
-- i.e. its identifier matches @seqno@ and its question section is acceptable
-- for @q@.
checkResp :: Question -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = isNothing (checkRespM q seqno resp)
-- | Validate a reply against the query that produced it.
--
-- * 'SequenceNumberMismatch' when the header identifier is not @seqno@.
-- * No error for a 'FormatErr' reply with an empty question section.
--   Presumably this is because a server that could not parse the query
--   need not echo the question back.
-- * 'QuestionMismatch' when the question section is anything other than
--   exactly @[q]@.
-- * Otherwise no error.
checkRespM :: Question -> Identifier -> DNSMessage -> Maybe DNSError
checkRespM q seqno resp
  | identifier hdr /= seqno = Just SequenceNumberMismatch
  | otherwise = case question resp of
      [] | FormatErr <- rcode (flags hdr) -> Nothing
      qs | [q] /= qs -> Just QuestionMismatch
         | otherwise -> Nothing
  where
    hdr = header resp
-- | Internal control-flow signal (not exported from this module).
--
-- The UDP lookup throws it when a reply arrives with the truncation (TC)
-- flag set. 'udpTcpLookup' catches it and repeats the query over TCP.
data TCPFallback = TCPFallback
  deriving (Show, Typeable)

instance Exception TCPFallback
-- | Trailing arguments shared by the resolver entry points:
-- query-control overrides, then a function that reads one reply from a
-- socket. DNS failures are reported as 'Left'.
type Rslv0 = QueryControls
          -> (Socket -> IO DNSMessage)
          -> IO (Either DNSError DNSMessage)

-- | 'Rslv0' preceded by the question, the per-attempt timeout and the retry
-- count. The timeout is ultimately handed to 'System.Timeout.timeout', so it
-- is in microseconds.
type Rslv1 = Question
          -> Int            -- timeout
          -> Int            -- retry count
          -> Rslv0

-- | A single TCP lookup against one server address.
type TcpRslv = AddrInfo
            -> Question
            -> Int          -- timeout
            -> QueryControls
            -> IO DNSMessage

-- | A UDP lookup: the retry count and the reply reader, followed by the
-- same arguments as 'TcpRslv'.
type UdpRslv = Int          -- retry count
            -> (Socket -> IO DNSMessage)
            -> TcpRslv
-- | Look up @dom@ / @typ@ using the resolver's configured nameservers.
--
-- * An illegal domain ('isIllegal') yields @Left IllegalDomain@, and an
--   'AXFR' request yields @Left InvalidAXFRLookup@. Both are returned
--   before any network activity.
-- * A trailing @'.'@ is appended to the domain when it is missing.
-- * With a single nameserver, that server is queried directly.
-- * With several nameservers, they are queried concurrently or one after
--   another, according to 'resolvConcurrent'. Servers are paired
--   positionally with the identifier generators from 'genIds'.
--
-- The caller's @qctls@ are combined with the configured
-- 'resolvQueryControls' using '<>', caller first. Which side wins a
-- conflict depends on 'QueryControls'' 'Semigroup' instance; presumably
-- the left side does (confirm there).
--
-- The single-server path uses 'NE.head' on the underlying 'NonEmpty'
-- values instead of list 'head'. The trailing-dot test uses
-- 'BS.isSuffixOf' instead of 'BS.last'. This keeps the function total
-- without relying on guard evaluation order.
resolve :: Resolver -> Domain -> TYPE -> Rslv0
resolve rlv dom typ qctls rcv
  | isIllegal dom = return $ Left IllegalDomain
  | typ == AXFR   = return $ Left InvalidAXFRLookup
  | onlyOne       = resolveOne (NE.head nsList) (NE.head genList) q tm retry ctls rcv
  | concurrent    = resolveConcurrent nss gens q tm retry ctls rcv
  | otherwise     = resolveSequential nss gens q tm retry ctls rcv
  where
    -- Fully qualify the query name.
    q :: Question
    q | "." `BS.isSuffixOf` dom = Question dom typ
      | otherwise               = Question (dom <> ".") typ

    seed :: ResolvSeed
    seed = resolvseed rlv

    conf :: ResolvConf
    conf = resolvconf seed

    nsList :: NE.NonEmpty AddrInfo
    nsList = nameservers seed

    genList :: NE.NonEmpty (IO Identifier)
    genList = genIds rlv

    nss :: [AddrInfo]
    nss = NE.toList nsList

    gens :: [IO Identifier]
    gens = NE.toList genList

    -- Exactly one nameserver configured.
    onlyOne :: Bool
    onlyOne = null (NE.tail nsList)

    ctls :: QueryControls
    ctls = qctls <> resolvQueryControls conf

    concurrent :: Bool
    concurrent = resolvConcurrent conf

    tm :: Int
    tm = resolvTimeout conf

    retry :: Int
    retry = resolvRetry conf
-- | Try the nameservers one at a time, in list order.
--
-- Each server is paired positionally with an identifier generator. The
-- first 'Right' result is returned immediately. A 'Left' moves on to the
-- next server, and the last server's result is returned as-is.
--
-- Calls 'error' if the two lists are empty or run out at different points.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q tm retry ctls rcv = go nss gs
  where
    attempt :: AddrInfo -> IO Identifier -> IO (Either DNSError DNSMessage)
    attempt ai gen = resolveOne ai gen q tm retry ctls rcv

    go :: [AddrInfo] -> [IO Identifier] -> IO (Either DNSError DNSMessage)
    go [ai]        [gen]         = attempt ai gen
    go (ai : ais') (gen : gens') =
        attempt ai gen >>= either (\_ -> go ais' gens') (return . Right)
    go _           _             = error "resolveSequential:loop"
-- | Query all nameservers at once and return whichever result arrives first.
--
-- Each server is paired positionally with an identifier generator. The
-- first completed query wins even if its result is a 'Left'.
--
-- Every query runs under 'withAsync'. All outstanding queries are therefore
-- cancelled both when the winner completes and when the calling thread is
-- interrupted while waiting (for example by an enclosing 'timeout'). A bare
-- 'async' followed by 'waitAnyCancel' only cancels on normal completion. It
-- would leave the losing queries, and their sockets, running after an
-- interruption.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q tm retry ctls rcv =
    withAll (zipWith query nss gens) (fmap snd . waitAny)
  where
    query :: AddrInfo -> IO Identifier -> IO (Either DNSError DNSMessage)
    query ai gen = resolveOne ai gen q tm retry ctls rcv

    -- Start each action under its own 'withAsync', nesting the scopes.
    -- The continuation receives the handles in the same order as the
    -- actions, and every handle is cancelled when the continuation exits.
    withAll :: [IO a] -> ([Async a] -> IO b) -> IO b
    withAll []           k = k []
    withAll (act : acts) k = withAsync act $ \a -> withAll acts (k . (a :))
-- | Resolve against a single nameserver: UDP first, with TCP fallback
-- (see 'udpTcpLookup').
--
-- Only 'DNSError' exceptions are turned into 'Left', as fixed by the
-- 'Rslv0' result type. Any other exception propagates.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q tm retry ctls rcv = E.try lookupOnce
  where
    lookupOnce :: IO DNSMessage
    lookupOnce = udpTcpLookup gen retry rcv ai q tm ctls
-- | UDP lookup with a TCP retry on truncation.
--
-- One identifier is drawn from @gen@ for the UDP query. If 'udpLookup'
-- signals 'TCPFallback', the query is repeated with 'tcpLookup'. That call
-- receives @gen@ itself and draws its own identifier. Only the UDP lookup is
-- covered by the handler; drawing the identifier is not.
udpTcpLookup :: IO Identifier -> UdpRslv
udpTcpLookup gen retry rcv ai q tm ctls = do
    ident <- gen
    E.handle overTCP (udpLookup ident retry rcv ai q tm ctls)
  where
    overTCP :: TCPFallback -> IO DNSMessage
    overTCP TCPFallback = tcpLookup gen ai q tm ctls
-- | Handler that rethrows an 'IOError' as a 'DNSError'.
--
-- The error is annotated with the location @"<protoName>\@<server address>"@
-- and wrapped in 'NetworkFailure'. It never returns normally. The
-- @IO DNSMessage@ result type only lets it be used as a handler around
-- lookups.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai protoName ioe = throwIO (NetworkFailure annotated)
  where
    annotated :: IOError
    annotated = annotateIOError ioe location Nothing Nothing

    location :: String
    location = concat [protoName, "@", show (addrAddress ai)]
-- | Create a socket from the family, socket type and protocol in the
-- 'AddrInfo', and connect it to that 'AddrInfo''s address.
--
-- The new socket is closed if 'connect' throws. This function is used as
-- the acquire step of 'bracket' in 'udpLookup'. An exception escaping from
-- here means the release action never receives the socket, so without
-- 'bracketOnError' the descriptor would leak.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    E.bracketOnError
        (socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai))
        close
        (\sock -> connect sock (addrAddress ai) >> return sock)
-- | One UDP lookup against a single server.
--
-- The query is encoded with @ident@ and sent on a socket from 'udpOpen';
-- 'bracket' closes the socket afterwards.
--
-- Each attempt:
--
-- * sends the query and waits up to @tm@ microseconds (the unit of
--   'timeout') for a reply that passes 'checkResp';
-- * discards replies that fail that check and keeps reading within the
--   same attempt;
-- * on timeout, consumes one of the @retry@ attempts. When no attempts
--   remain, 'RetryLimitExceeded' is thrown.
--
-- The reply is then inspected:
--
-- * with the truncation flag set, 'TCPFallback' is thrown;
-- * with 'FormatErr' and no EDNS header, the query is re-sent with
--   @ednsEnabled FlagClear@ merged into the controls, provided that
--   actually changes them. This does not consume an attempt.
--
-- 'IOError's are rethrown as 'NetworkFailure' via 'ioErrorToDNSError'.
--
-- The retry guard is @cnt >= retry@ rather than @cnt == retry@. This is
-- identical for non-negative counts, but a negative @retry@ now stops
-- instead of retrying timeouts forever.
udpLookup :: Identifier -> UdpRslv
udpLookup ident retry rcv ai q tm ctls =
    E.handle (ioErrorToDNSError ai "udp") $
        bracket (udpOpen ai) close (loop (encodeQuestion ident q ctls) ctls 0)
  where
    -- Send @qry@ (encoded with controls @lctls@) and await a reply.
    -- @cnt@ counts the attempts that have timed out so far.
    loop :: Domain -> QueryControls -> Int -> Socket -> IO DNSMessage
    loop qry lctls cnt sock
      | cnt >= retry = E.throwIO RetryLimitExceeded
      | otherwise    = do
          mres <- timeout tm (send sock qry >> getAns sock)
          case mres of
              Nothing  -> loop qry lctls (cnt + 1) sock
              Just res -> inspect lctls cnt sock res

    -- Decide what to do with a validated reply.
    inspect :: QueryControls -> Int -> Socket -> DNSMessage -> IO DNSMessage
    inspect lctls cnt sock res
      | trunCation fl = E.throwIO TCPFallback
      | rcode fl == FormatErr && ednsHeader res == NoEDNS && cs /= lctls
                      = loop (encodeQuestion ident q cs) cs cnt sock
      | otherwise     = return res
      where
        fl = flags (header res)
        cs = ednsEnabled FlagClear <> lctls

    -- Read replies until one matches our identifier and question.
    getAns :: Socket -> IO DNSMessage
    getAns sock = do
        resp <- rcv sock
        if checkResp q ident resp
            then return resp
            else getAns sock
-- | Create an unconnected stream socket suitable for reaching @peer@.
--
-- IPv4 and IPv6 addresses get an 'AF_INET' or 'AF_INET6' socket with the
-- default protocol. Any other address kind throws 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen peer = maybe unsupported open (streamFamily peer)
  where
    open :: Family -> IO Socket
    open fam = socket fam Stream defaultProtocol

    unsupported :: IO Socket
    unsupported = E.throwIO ServerFailure

    streamFamily :: SockAddr -> Maybe Family
    streamFamily SockAddrInet{}  = Just AF_INET
    streamFamily SockAddrInet6{} = Just AF_INET6
    streamFamily _               = Nothing
-- | One TCP lookup against a single server.
--
-- Each query uses a fresh connection, opened with 'tcpOpen' and closed by
-- 'bracket', and a fresh identifier drawn from @gen@. Connecting, sending
-- and receiving together must finish within @tm@ microseconds, otherwise
-- 'TimeoutExpired' is thrown. A reply rejected by 'checkRespM' is thrown as
-- the corresponding 'DNSError'.
--
-- If the first reply is 'FormatErr' with no EDNS header, the query is
-- repeated once on a new connection with @ednsEnabled FlagClear@ merged
-- into the controls. This only happens when that merge actually changes
-- the controls.
--
-- 'IOError's are rethrown as 'NetworkFailure' via 'ioErrorToDNSError'.
tcpLookup :: IO Identifier -> TcpRslv
tcpLookup gen ai q tm ctls = E.handle (ioErrorToDNSError ai "tcp") $ do
    firstReply <- viaTCP ctls
    let hdrFlags = flags (header firstReply)
        noEdns   = ednsEnabled FlagClear <> ctls
        retryWithoutEdns =
               rcode hdrFlags == FormatErr
            && ednsHeader firstReply == NoEDNS
            && noEdns /= ctls
    if retryWithoutEdns
        then viaTCP noEdns
        else return firstReply
  where
    addr :: SockAddr
    addr = addrAddress ai

    -- A single query on its own connection.
    viaTCP :: QueryControls -> IO DNSMessage
    viaTCP cs = bracket (tcpOpen addr) close (perform cs)

    perform :: QueryControls -> Socket -> IO DNSMessage
    perform cs vc = do
        ident <- gen
        let qry = encodeQuestion ident q cs
        mres <- timeout tm $ connect vc addr >> sendVC vc qry >> receiveVC vc
        case mres of
            Nothing  -> E.throwIO TimeoutExpired
            Just res -> maybe (return res) E.throwIO (checkRespM q ident res)
-- | Whether a domain name's textual length is out of bounds.
--
-- Empty names are rejected. Otherwise the limit is 254 bytes when the name
-- ends in @'.'@ and 253 bytes when it does not.
badLength :: Domain -> Bool
badLength dom
  | BS.null dom = True
  | otherwise   = BS.length dom > limit
  where
    -- Only evaluated for non-empty names, so 'BS.last' is safe here.
    limit :: Int
    limit | BS.last dom == '.' = 254
          | otherwise          = 253
-- | Reject domain names this resolver will not query.
--
-- A name is illegal if any of the following holds, checked in this order:
--
-- * its length is out of range ('badLength');
-- * it contains no @'.'@;
-- * it contains @':'@ or @'/'@;
-- * any @'.'@-separated label is longer than 63 bytes.
isIllegal :: Domain -> Bool
isIllegal dom =
       badLength dom
    || '.' `BS.notElem` dom
    || ':' `BS.elem` dom
    || '/' `BS.elem` dom
    || any ((> 63) . BS.length) (BS.split '.' dom)