{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Trustworthy #-}
module Network.DNS
(
queryA
, queryAAAA
, queryCNAME
, queryPTR
, querySRV
, queryTXT
, query
, DnsException(..)
, resIsReentrant
, queryRaw
, sendRaw
, mkQueryRaw
, decodeMessage
, encodeMessage
, mkQueryMsg
, Label
, Labels(..)
, IsLabels(..)
, Name(..)
, caseFoldName
, CharStr(..)
, IPv4(..), arpaIPv4
, IPv6(..), arpaIPv6
, TTL(..)
, Class(..)
, classIN
, Type(..)
, TypeSym(..)
, typeFromSym
, typeToSym
, Msg(..)
, MsgHeader(..)
, MsgHeaderFlags(..), QR(..)
, MsgQuestion(..)
, MsgRR(..)
, RData(..)
, rdType
, SRV(..)
)
where
import Control.Exception
import Data.Bits (unsafeShiftR, (.&.))
import Data.Typeable (Typeable)
import Foreign.C
import Foreign.Marshal.Alloc
import Numeric (showInt)
import Prelude
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC
import Compat
import Network.DNS.FFI
import Network.DNS.Message
-- | Exceptions thrown by 'query' (and therefore by the @query*@ helpers that
-- call it).
data DnsException
  = DnsEncodeException
    -- ^ The requested name was rejected by 'toName'.
  | DnsDecodeException
    -- ^ The raw reply could not be parsed by 'decodeMessage'.
  deriving (Typeable, Show)

instance Exception DnsException
-- | Resolve a name for the given class and record type, and decode the reply.
--
-- Throws 'DnsEncodeException' when 'toName' rejects @name0@, and
-- 'DnsDecodeException' when 'decodeMessage' cannot parse the raw response.
-- Any exception raised by 'queryRaw' propagates unchanged.
query :: IsLabels n => Class -> n -> TypeSym -> IO (Msg n)
query cls name0 qtype =
  case toName name0 of
    Nothing -> throwIO DnsEncodeException
    Just name -> do
      raw <- queryRaw cls name (typeFromSym qtype)
      -- force the Just/Nothing decision here so decode failure is raised now
      decoded <- evaluate (decodeMessage raw)
      case decoded of
        Nothing  -> throwIO DnsDecodeException
        Just msg -> pure msg
-- | Issue a query for the given class, name and record type through the
-- @c_res_query@ binding (presumably @res_nquery(3)@ — see "Network.DNS.FFI")
-- and return the raw reply bytes.
--
-- A zero-filled 64 KiB buffer receives the reply. Errors:
--
--   * a non-zero result from @c_res_opt_set_use_dnssec@ is reported with
--     'fail' as a @res_init(3)@ failure;
--   * a reported length larger than the buffer is reported with 'fail';
--   * a negative length raises an errno-based 'IOError' when errno was set
--     by the call, and a plain 'fail' otherwise.
queryRaw :: Class -> Name -> Type -> IO BS.ByteString
queryRaw (Class cls) (Name name) qtype =
    withCResState $ \st ->
      allocaBytes bufSize $ \buf -> do
        _ <- c_memset buf 0 bufSize
        BS.useAsCString name $ \cname -> do
          rc <- c_res_opt_set_use_dnssec st
          when (rc /= 0) $ fail "res_init(3) failed"
          -- clear errno so a failure below can tell whether errno is meaningful
          resetErrno
          got <- c_res_query st cname (fromIntegral cls) rrType buf bufSize
          when (got > bufSize) $ fail "res_query(3) message size overflow"
          err <- getErrno
          when (got < 0) $ do
            unless (err == eOK) $ throwErrno "res_query"
            fail "res_query(3) failed"
          BS.packCStringLen (buf, fromIntegral got)
  where
    -- size of the reply buffer; used at several numeric types
    bufSize :: Num a => a
    bufSize = 0x10000

    rrType :: CInt
    rrType = case qtype of Type w -> fromIntegral w
-- | Send an already-encoded query message through the @c_res_send@ binding
-- (presumably @res_nsend(3)@ — see "Network.DNS.FFI") and return the raw
-- reply bytes.
--
-- A zero-filled 64 KiB buffer receives the reply. Errors:
--
--   * a non-zero result from @c_res_opt_set_use_dnssec@ is reported with
--     'fail' as a @res_init(3)@ failure;
--   * a reported length larger than the buffer is reported with 'fail';
--   * a negative length raises an errno-based 'IOError' when errno was set
--     by the call, and a plain 'fail' otherwise.
sendRaw :: BS.ByteString -> IO BS.ByteString
sendRaw req =
    withCResState $ \st ->
      allocaBytes bufSize $ \buf -> do
        _ <- c_memset buf 0 bufSize
        BS.useAsCStringLen req $ \(msgPtr, msgLen) -> do
          rc <- c_res_opt_set_use_dnssec st
          when (rc /= 0) $ fail "res_init(3) failed"
          -- clear errno so a failure below can tell whether errno is meaningful
          resetErrno
          got <- c_res_send st msgPtr (fromIntegral msgLen) buf bufSize
          when (got > bufSize) $ fail "res_send(3) message size overflow"
          err <- getErrno
          when (got < 0) $ do
            unless (err == eOK) $ throwErrno "res_send"
            fail "res_send(3) failed"
          BS.packCStringLen (buf, fromIntegral got)
  where
    -- size of the reply buffer; used at several numeric types
    bufSize :: Num a => a
    bufSize = 0x10000
-- | Build a single-question query message in pure Haskell.
--
-- The header uses the fixed id 31337 and sets the RD and AD flags. The
-- additional section carries one OPT pseudo-record (root owner, empty RDATA).
-- Per RFC 6891, its class field (512) advertises the UDP payload size and
-- TTL 0x8000 sets the DO bit.
mkQueryMsg :: IsLabels n => Class -> n -> Type -> Msg n
mkQueryMsg cls l qtype =
    Msg header [question] [] [] [optRecord]
  where
    question = MsgQuestion l qtype cls

    header = MsgHeader
      { mhId      = 31337
      , mhFlags   = flags
      , mhQDCount = 1
      , mhANCount = 0
      , mhNSCount = 0
      , mhARCount = 1
      }

    flags = MsgHeaderFlags
      { mhQR     = IsQuery
      , mhOpcode = 0
      , mhAA     = False
      , mhTC     = False
      , mhRD     = True
      , mhRA     = False
      , mhZ      = False
      , mhAD     = True
      , mhCD     = False
      , mhRCode  = 0
      }

    optRecord = MsgRR
      { rrName  = fromLabels Root
      , rrClass = Class 512
      , rrTTL   = TTL 0x8000
      , rrData  = RDataOPT ""
      }
-- | Encode a query for the given class, name and record type through the
-- @c_res_mkquery@ binding (presumably @res_nmkquery(3)@ — confirm in
-- "Network.DNS.FFI") and return the resulting wire-format bytes.
--
-- A zero-filled 64 KiB buffer receives the message. Errors:
--
--   * a non-zero result from @c_res_opt_set_use_dnssec@ is reported with
--     'fail' as a @res_init(3)@ failure;
--   * a reported length larger than the buffer is reported with 'fail';
--   * a negative length raises an errno-based 'IOError' labelled
--     @res_mkquery@ when errno was set by the call, and a plain 'fail'
--     otherwise.
mkQueryRaw :: Class -> Name -> Type -> IO BS.ByteString
mkQueryRaw (Class cls) (Name name) qtype =
    withCResState $ \stptr ->
      allocaBytes max_msg_size $ \resptr -> do
        _ <- c_memset resptr 0 max_msg_size
        BS.useAsCString name $ \dn -> do
          rc1 <- c_res_opt_set_use_dnssec stptr
          unless (rc1 == 0) $
            fail "res_init(3) failed"
          -- clear errno so a failure below can tell whether errno is meaningful
          resetErrno
          reslen <- c_res_mkquery stptr dn (fromIntegral cls) qtypeVal resptr max_msg_size
          unless (reslen <= max_msg_size) $
            fail "res_mkquery(3) message size overflow"
          errno <- getErrno
          when (reslen < 0) $ do
            -- label the IOError with the call that actually failed
            -- (previously mislabelled as "res_query")
            unless (errno == eOK) $
              throwErrno "res_mkquery"
            fail "res_mkquery(3) failed"
          BS.packCStringLen (resptr, fromIntegral reslen)
  where
    -- size of the output buffer; used at several numeric types
    max_msg_size :: Num a => a
    max_msg_size = 0x10000

    qtypeVal :: CInt
    qtypeVal = case qtype of Type w -> fromIntegral w
-- | Normalise a name for comparison.
--
-- ASCII letters @a@–@z@ are mapped to upper case, all other bytes are kept.
-- The result always ends in a @'.'@; an empty name becomes @"."@.
caseFoldName :: Name -> Name
caseFoldName (Name raw) = Name (terminate (BS.map upcase raw))
  where
    terminate bs
      | BS.null bs         = "."
      | BS.last bs == 0x2e = bs
      | otherwise          = bs `mappend` "."

    upcase c
      | c >= 0x61, c <= 0x7a = c - 0x20
      | otherwise            = c
-- | Look up the A records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each paired with its TTL.
queryA :: Name -> IO [(TTL,IPv4)]
queryA n = fmap pick (query classIN n' TypeA)
  where
    n' = caseFoldName n
    pick res =
      [ (ttl, ip4)
      | MsgRR { rrData = RDataA ip4, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN res
      , caseFoldName owner == n'
      ]
-- | Look up the AAAA records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each paired with its TTL.
queryAAAA :: Name -> IO [(TTL,IPv6)]
queryAAAA n = fmap pick (query classIN n' TypeAAAA)
  where
    n' = caseFoldName n
    pick res =
      [ (ttl, ip6)
      | MsgRR { rrData = RDataAAAA ip6, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN res
      , caseFoldName owner == n'
      ]
-- | Look up the CNAME records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each paired with its TTL.
--
-- The query asks for type CNAME directly. Previously it asked for AAAA and
-- relied on the resolver including the CNAME while chasing it. That failed
-- outright for CNAMEs whose target does not exist, and it forced a needless
-- AAAA resolution of the target.
queryCNAME :: Name -> IO [(TTL,Name)]
queryCNAME n = do
    res <- query classIN n' TypeCNAME
    pure [ (ttl,cname) | MsgRR { rrData = RDataCNAME cname, rrTTL = ttl, rrName = n1, rrClass = Class 1 } <- msgAN res, caseFoldName n1 == n' ]
  where
    n' = caseFoldName n
-- | Look up the PTR records of a name (e.g. one built by 'arpaIPv4').
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each paired with its TTL.
queryPTR :: Name -> IO [(TTL,Name)]
queryPTR n = fmap pick (query classIN n' TypePTR)
  where
    n' = caseFoldName n
    pick res =
      [ (ttl, target)
      | MsgRR { rrData = RDataPTR target, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN res
      , caseFoldName owner == n'
      ]
-- | Look up the TXT records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each record's character-strings paired with its TTL.
queryTXT :: Name -> IO [(TTL,[CharStr])]
queryTXT n = fmap pick (query classIN n' TypeTXT)
  where
    n' = caseFoldName n
    pick res =
      [ (ttl, strs)
      | MsgRR { rrData = RDataTXT strs, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN res
      , caseFoldName owner == n'
      ]
-- | Look up the SRV records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only answer
-- records of class 1 whose owner name folds to the same value are kept,
-- each paired with its TTL.
querySRV :: Name -> IO [(TTL,SRV Name)]
querySRV n = fmap pick (query classIN n' TypeSRV)
  where
    n' = caseFoldName n
    pick res =
      [ (ttl, srv)
      | MsgRR { rrData = RDataSRV srv, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN res
      , caseFoldName owner == n'
      ]
-- | Build the reverse-lookup name under @in-addr.arpa.@ for an IPv4 address.
--
-- The four bytes of the wrapped word are emitted in decimal, least
-- significant byte first and dot-separated, followed by @.in-addr.arpa.@.
-- Assuming the word holds the address with its first octet in the most
-- significant byte (TODO confirm in "Network.DNS.Message"), @127.0.0.1@
-- yields @1.0.0.127.in-addr.arpa.@.
arpaIPv4 :: IPv4 -> Name
arpaIPv4 (IPv4 w) = Name (BSC.pack (dotted octets ".in-addr.arpa."))
  where
    -- least significant byte first
    octets :: [Word8]
    octets = [ fromIntegral (w `unsafeShiftR` sh) | sh <- [0, 8, 16, 24] ]

    dotted :: [Word8] -> ShowS
    dotted []       = id
    dotted [o]      = showInt o
    dotted (o : os) = showInt o . showChar '.' . dotted os
-- | Build the reverse-lookup name under @ip6.arpa.@ for an IPv6 address.
--
-- Every 4-bit nibble is emitted as a lower-case hex digit followed by a dot.
-- The order is least significant first: all 16 nibbles of the low word, then
-- all 16 of the high word, then the literal suffix @ip6.arpa.@.
arpaIPv6 :: IPv6 -> Name
arpaIPv6 (IPv6 hi lo) = Name (BSC.pack (nibbles lo (nibbles hi "ip6.arpa.")))
  where
    -- prepend the 16 nibbles of a 64-bit word, lowest first, each dot-terminated
    nibbles :: Word64 -> ShowS
    nibbles word rest = foldr emit rest [0 .. 15 :: Int]
      where
        emit i acc = hexDigit ((word `unsafeShiftR` (4 * i)) .&. 0xf) : '.' : acc

    -- 0..9 -> '0'..'9', 10..15 -> 'a'..'f'
    hexDigit :: Word64 -> Char
    hexDigit d
      | d < 10    = toEnum (fromIntegral (0x30 + d))
      | otherwise = toEnum (fromIntegral (0x57 + d))