{-# LANGUAGE BangPatterns, OverloadedStrings #-}
module Network.DNS.Decode.Parsers (
getResponse
, getDNSFlags
, getHeader
, getResourceRecord
, getResourceRecords
, getDomain
, getMailbox
) where
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS
import qualified Data.IP
import Data.IP (IP(..), toIPv4, toIPv6b, makeAddrRange)
import Data.List (partition)
import Network.DNS.Imports
import Network.DNS.StateBinary
import Network.DNS.Types
-- | Decode a complete DNS message.
--
-- The header is followed by four 16-bit counts (question, answer, authority,
-- additional) and then the sections themselves. OPT records are split out of
-- the additional section and turned into the 'EDNSheader'. The header RCODE
-- is replaced by the value obtained from combining its low 4 bits with the
-- upper bits carried in the OPT TTL. When there is more than one OPT record,
-- or the single one does not have a root owner and 'RD_OPT' data, the result
-- is 'InvalidEDNS' with 'BadRCODE'.
getResponse :: SGet DNSMessage
getResponse = do
    rawHdr       <- getHeader
    nQuestions   <- getInt16
    nAnswers     <- getInt16
    nAuthorities <- getInt16
    nAdditionals <- getInt16
    questions    <- getQueries nQuestions
    answerRRs    <- getResourceRecords nAnswers
    authorityRRs <- getResourceRecords nAuthorities
    extraRRs     <- getResourceRecords nAdditionals
    let (optRRs, nonOptRRs) = partition isOptRR extraRRs
        hdrFlags            = flags rawHdr
        (ednsHdr, fullRc)   = decodeEDNS (fromRCODE (rcode hdrFlags)) optRRs
        finalHdr            = rawHdr { flags = hdrFlags { rcode = fullRc } }
        -- With usable EDNS the OPT records are dropped from the section.
        keptRRs             = ifEDNS ednsHdr nonOptRRs extraRRs
    return (DNSMessage finalHdr ednsHdr questions answerRRs authorityRRs keptRRs)
  where
    isOptRR rr = rrtype rr == OPT

    -- Interpret the OPT records found in the additional section, given the
    -- 4-bit RCODE from the header flags.
    decodeEDNS :: Word16 -> AdditionalRecords -> (EDNSheader, RCODE)
    decodeEDNS hdrRc optRRs = case optRRs of
        []   -> (NoEDNS, toRCODE hdrRc)
        [rr] -> maybe invalid (\(e, x) -> (EDNSheader e, toRCODE x)) (fromOptRR rr)
        _    -> invalid
      where
        invalid = (InvalidEDNS, BadRCODE)

        -- The OPT TTL field packs: extended RCODE (bits 31..24),
        -- version (bits 23..16), DO flag (bit 15).
        fromOptRR :: ResourceRecord -> Maybe (EDNS, Word16)
        fromOptRR (ResourceRecord "." OPT udpSize ttlBits (RD_OPT opts)) =
            let lowRc    = fromIntegral hdrRc .&. 0x0f
                extRc    = ((ttlBits .&. 0xff000000) `shiftR` 20) .|. lowRc
                dnssecOk = ttlBits `testBit` 15
                version  = fromIntegral ((ttlBits .&. 0x00ff0000) `shiftR` 16)
             in Just (EDNS version udpSize dnssecOk opts, fromIntegral extRc)
        fromOptRR _ = Nothing
-- | Decode the 16-bit flags word that follows the message identifier.
--
-- Bits are numbered from 15 (most significant):
--
-- > QR(15) | OPCODE(14..11) | AA(10) | TC(9) | RD(8) | RA(7) | (6) | AD(5) | CD(4) | RCODE(3..0)
--
-- Fails when the OPCODE field is not recognised by 'toOPCODE'.
getDNSFlags :: SGet DNSFlags
getDNSFlags = do
    w16 <- get16
    let bitSet  = testBit w16
        opField = (w16 `shiftR` 11) .&. 0x0f
    opc <- case toOPCODE opField of
        Just known -> pure known
        Nothing    -> failSGet $ "Unsupported header opcode: " ++ show opField
    return $ DNSFlags
        (if bitSet 15 then QR_Response else QR_Query)
        opc
        (bitSet 10)
        (bitSet 9)
        (bitSet 8)
        (bitSet 7)
        (toRCODE (w16 .&. 0x0f))
        (bitSet 5)
        (bitSet 4)
-- | Decode the fixed message header: a 16-bit identifier followed by the
-- flags word.
getHeader :: SGet DNSHeader
getHeader = do
    ident <- get16
    DNSHeader ident <$> getDNSFlags
-- | Decode @count@ question entries in wire order (none when @count <= 0@).
getQueries :: Int -> SGet [Question]
getQueries count = sequence (replicate count getQuery)
-- | Decode a 16-bit RR type code.
getTYPE :: SGet TYPE
getTYPE = fmap toTYPE get16
-- | Decode one question entry: QNAME and QTYPE, followed by a 16-bit QCLASS
-- that is consumed but not kept in the 'Question'.
getQuery :: SGet Question
getQuery = do
    dom <- getDomain
    typ <- getTYPE
    _qclass <- get16
    pure (Question dom typ)
-- | Decode @n@ resource records in wire order (none when @n <= 0@).
getResourceRecords :: Int -> SGet [ResourceRecord]
getResourceRecords n = mapM (const getResourceRecord) [1 .. n]
-- | Decode one resource record: owner name, TYPE, 16-bit CLASS, 32-bit TTL,
-- RDLENGTH, then the type-specific RDATA.
--
-- The RDATA decoder is wrapped in 'fitSGet' with RDLENGTH. That presumably
-- confines it to exactly that many bytes; confirm in StateBinary.
getResourceRecord :: SGet ResourceRecord
getResourceRecord = do
    owner <- getDomain
    rrTy  <- getTYPE
    rrCls <- get16
    rrTtl <- get32
    rdLen <- getInt16
    ResourceRecord owner rrTy rrCls rrTtl <$> fitSGet rdLen (getRData rrTy rdLen)
-- | Input offset just past an RDATA region of @len@ bytes that starts at the
-- current position.
rdataEnd :: Int -> SGet Int
rdataEnd !len = fmap (+ len) getPosition
-- | Decode the RDATA of a resource record.
--
-- The 'TYPE' selects the wire layout. The 'Int' is the record's RDLENGTH;
-- it sizes the trailing variable-length fields (digests, keys, signatures,
-- type bitmaps). Types without a dedicated decoder are returned as raw bytes
-- in 'UnknownRData'.
getRData :: TYPE -> Int -> SGet RData
getRData NS _ = RD_NS <$> getDomain
getRData MX _ = do
    pref <- get16
    RD_MX pref <$> getDomain
getRData CNAME _ = RD_CNAME <$> getDomain
getRData DNAME _ = RD_DNAME <$> getDomain
getRData TXT len = RD_TXT <$> getTXT len
getRData A _ = fmap (RD_A . toIPv4) (getNBytes 4)
getRData AAAA _ = fmap (RD_AAAA . toIPv6b) (getNBytes 16)
getRData SOA _ = do
    primaryNs  <- getDomain
    adminMbox  <- getMailbox
    -- serial, refresh, retry, expire, minimum: five 32-bit fields
    serialNo   <- get32
    refreshIv  <- get32
    retryIv    <- get32
    expireIv   <- get32
    minimumTtl <- get32
    return $ RD_SOA primaryNs adminMbox serialNo refreshIv retryIv expireIv minimumTtl
getRData PTR _ = RD_PTR <$> getDomain
getRData SRV _ = do
    prio   <- get16
    weight <- get16
    portNo <- get16
    RD_SRV prio weight portNo <$> getDomain
getRData OPT len = RD_OPT <$> getOpts len
getRData TLSA len = do
    usage    <- get8
    selector <- get8
    matching <- get8
    -- Three fixed octets precede the association data.
    RD_TLSA usage selector matching <$> getNByteString (len - 3)
getRData DS len = do
    keyTag     <- get16
    keyAlg     <- get8
    digestType <- get8
    -- 2 + 1 + 1 fixed octets precede the digest.
    RD_DS keyTag keyAlg digestType <$> getNByteString (len - 4)
getRData RRSIG len = do
    end        <- rdataEnd len
    covered    <- getTYPE
    sigAlg     <- get8
    labelCount <- get8
    origTtl    <- get32
    expiration <- timestamp
    inception  <- timestamp
    keyTag     <- get16
    signer     <- getDomain
    here       <- getPosition
    -- The signature occupies whatever is left of the RDATA.
    sig        <- getNByteString (end - here)
    return . RD_RRSIG $
        RDREP_RRSIG covered sigAlg labelCount origTtl expiration inception keyTag signer sig
  where
    -- Read a 32-bit timestamp and resolve it with 'dnsTime' against the
    -- value supplied by 'getAtTime' (fetched before the 32 bits are read).
    timestamp = do
        now  <- getAtTime
        secs <- get32
        return $! dnsTime secs now
getRData NULL len = RD_NULL <$> getNByteString len
getRData NSEC len = do
    end  <- rdataEnd len
    next <- getDomain
    here <- getPosition
    -- The remainder of the RDATA is the type bitmap.
    RD_NSEC next <$> getNsecTypes (end - here)
getRData DNSKEY len = do
    keyFlags <- get16
    proto    <- get8
    keyAlg   <- get8
    -- 2 + 1 + 1 fixed octets precede the public key.
    RD_DNSKEY keyFlags proto keyAlg <$> getNByteString (len - 4)
getRData NSEC3 len = do
    end        <- rdataEnd len
    hashAlg    <- get8
    nsec3Flags <- get8
    iterations <- get16
    salt       <- getInt8 >>= getNByteString
    nextHashed <- getInt8 >>= getNByteString
    here       <- getPosition
    RD_NSEC3 hashAlg nsec3Flags iterations salt nextHashed <$> getNsecTypes (end - here)
getRData NSEC3PARAM _ = do
    hashAlg    <- get8
    paramFlags <- get8
    iterations <- get16
    RD_NSEC3PARAM hashAlg paramFlags iterations <$> (getInt8 >>= getNByteString)
getRData _ len = UnknownRData <$> getNByteString len
-- | Decode TXT RDATA spanning @len@ bytes.
--
-- The RDATA is a run of length-prefixed character-strings, read via
-- 'sGetMany'. The strings are joined with no separator.
getTXT :: Int -> SGet ByteString
getTXT !len = mconcat <$> sGetMany "TXT RR string" len charString
  where
    charString = do
        size <- getInt8
        getNByteString size
-- | Decode @len@ bytes of EDNS options.
--
-- Each option is a 16-bit code and a 16-bit length, followed by a payload
-- that is decoded by 'getOData'.
getOpts :: Int -> SGet [OData]
getOpts !len = sGetMany "EDNS option" len option
  where
    option = do
        rawCode <- get16
        optLen  <- getInt16
        getOData (toOptCode rawCode) optLen
-- | Decode an NSEC/NSEC3 type bitmap occupying @len@ bytes.
--
-- The bitmap is a sequence of (window number, bitmap length, bitmap) blocks.
-- Bit @b@ (MSB first) of octet @i@ in window @w@ denotes type
-- @w * 256 + 8 * i + b@. Bitmaps longer than 32 octets are rejected.
getNsecTypes :: Int -> SGet [TYPE]
getNsecTypes !len = concat <$> sGetMany "NSEC type bitmap" len windowBlock
  where
    windowBlock = do
        winNo <- getInt8
        bmLen <- getInt8
        when (bmLen > 32) $
            failSGet $ "NSEC bitmap block too long: " ++ show bmLen
        octets <- getNBytes bmLen
        let base = winNo `shiftL` 8
        return
            [ toTYPE (fromIntegral (base + 8 * idx + bitNo))
            | (idx, octet) <- zip [0 ..] octets
            , bitNo <- [0 .. 7]
            , testBit octet (7 - bitNo)
            ]
-- | Decode the payload of one EDNS option.
--
-- @len@ is the OPTION-LENGTH that preceded the payload.
--
-- * NSID is kept as raw bytes.
-- * DAU, DHU and N3U are decoded with 'getNoctets'.
-- * Client Subnet is validated; malformed ones become 'OD_ECSgeneric'.
-- * Any other code becomes 'UnknownOData'.
getOData :: OptCode -> Int -> SGet OData
getOData NSID len = OD_NSID <$> getNByteString len
getOData DAU len = OD_DAU <$> getNoctets len
getOData DHU len = OD_DHU <$> getNoctets len
getOData N3U len = OD_N3U <$> getNoctets len
getOData ClientSubnet len = do
    family  <- get16
    srcBits <- get8
    scpBits <- get8
    addrbs  <- getNByteString (len - 4) -- 4 = FAMILY(2) + SOURCE(1) + SCOPE(1)
    -- RFC 7871 section 6: ADDRESS must be exactly
    -- ceil(SOURCE PREFIX-LENGTH / 8) octets, with no bits set beyond the
    -- prefix. Rather than failing the whole message, a malformed option is
    -- kept verbatim as OD_ECSgeneric.
    case BS.length addrbs == (fromIntegral srcBits + 7) `div` 8 of
        True | Just ip <- bstoip family addrbs srcBits scpBits
            -> pure $ OD_ClientSubnet srcBits scpBits ip
        _   -> pure $ OD_ECSgeneric family srcBits scpBits addrbs
  where
    -- The address with every bit after the first @bits@ cleared.
    prefix addr bits = Data.IP.addr $ makeAddrRange addr $ fromIntegral bits
    -- Widen the octets and append an infinite run of zeros; callers 'take'
    -- the family's address width.
    zeropad = (++ repeat 0). map fromIntegral. B.unpack
    -- Accept only when SOURCE and SCOPE prefix lengths both fit the address
    -- width and no bits are set beyond SOURCE PREFIX-LENGTH. Bounding
    -- srcBits also rejects an ADDRESS longer than the family allows (e.g. 5
    -- octets for IPv4 with srcBits = 40). Otherwise 'take' would silently
    -- drop the extra octets and report an impossible prefix length.
    checkBits fromBytes toIP srcBits scpBits bytes =
        let addr    = fromBytes bytes
            maxBits = fromIntegral $ 8 * length bytes
         in if srcBits <= maxBits && scpBits <= maxBits
                   && addr == prefix addr srcBits
            then Just $ toIP addr
            else Nothing
    bstoip :: Word16 -> B.ByteString -> Word8 -> Word8 -> Maybe IP
    bstoip family bs srcBits scpBits = case family of
        1 -> checkBits toIPv4  IPv4 srcBits scpBits $ take 4  $ zeropad bs
        2 -> checkBits toIPv6b IPv6 srcBits scpBits $ take 16 $ zeropad bs
        _ -> Nothing
getOData opc len = UnknownOData (fromOptCode opc) <$> getNByteString len
-- | Decode a possibly-compressed domain name at the current position,
-- joining its labels with @.@.
getDomain :: SGet Domain
getDomain = do
    start <- getPosition
    getDomain' dot start
-- | Decode a mailbox name at the current position. It is encoded like a
-- domain name, but the first label is separated from the rest by @\@@.
getMailbox :: SGet Mailbox
getMailbox = do
    start <- getPosition
    getDomain' atsign start
-- Separator bytes used when rendering decoded names:
-- 'dot' joins ordinary labels, 'atsign' follows the local part of a mailbox.
dot, atsign :: Word8
dot    = 0x2e -- ASCII '.'
atsign = 0x40 -- ASCII '@'
-- | Decode a domain name whose first byte is at the current position.
--
-- @sep1@ is the byte placed between the first label and the rest of the
-- name. The remaining labels are decoded by a recursive call that always
-- uses 'dot'. 'getDomain' passes 'dot' here; 'getMailbox' passes 'atsign'.
--
-- @ptrLimit@ bounds compression pointers: a pointer whose target offset is
-- @>= ptrLimit@ is rejected. When a pointer is followed, the target is
-- decoded with that target offset as the new limit. Every hop must
-- therefore point strictly further back, so pointer chains cannot loop.
--
-- Names decoded through a label or a pointer are recorded with 'push'
-- under the position where they start. 'pop' is consulted before
-- re-decoding a pointer target.
getDomain' :: Word8 -> Int -> SGet ByteString
getDomain' sep1 ptrLimit = do
    pos <- getPosition
    c <- getInt8
    -- Low 6 bits: a label length, or the high bits of a pointer offset.
    let n = getValue c
    getdomain pos c n
  where
    getdomain pos c n
      -- A zero byte is the root label and ends the name.
      | c == 0 = return "."
      -- 11xxxxxx: a 14-bit compression pointer (6 bits here + next byte).
      | isPointer c = do
          d <- getInt8
          let offset = n * 256 + d
          -- Only strictly backward pointers are allowed (see ptrLimit above).
          when (offset >= ptrLimit) $
              failSGet "invalid name compression pointer"
          mo <- pop offset
          case mo of
              Nothing -> do
                  -- Not cached: run a fresh decoder over 'getInput', skipping
                  -- to the pointer target first.
                  -- NOTE(review): assumes 'getInput' yields the whole message
                  -- starting at offset 0 -- confirm in StateBinary.
                  msg <- getInput
                  let parser = skipNBytes offset >> getDomain' sep1 offset
                  case runSGet parser msg of
                      Left (DecodeError err) -> failSGet err
                      Left err -> fail $ show err
                      Right o -> push pos (fst o) >> return (fst o)
              Just o -> push pos o >> return o
      -- 01xxxxxx: extended label types are not supported; decodes to "".
      | isExtLabel c = return ""
      -- Everything else is read as an ordinary label of length n.
      -- NOTE(review): this also accepts the reserved 10xxxxxx form.
      | otherwise = do
          hs <- unparseLabel sep1 <$> getNByteString n
          ds <- getDomain' dot ptrLimit
          -- If the rest is just the root, append "." rather than sep1 <> "."
          -- so the name does not end in a doubled separator.
          let dom = case ds of
                  "." -> hs <> "."
                  _ -> hs <> B.singleton sep1 <> ds
          push pos dom
          return dom
    getValue c = c .&. 0x3f
    isPointer c = testBit c 7 && testBit c 6
    isExtLabel c = not (testBit c 7) && testBit c 6