{-# LANGUAGE ScopedTypeVariables, ForeignFunctionInterface #-} module Network.DNS.Common where import Control.Monad (when) import Data.Word import Data.Bits import Data.List (intersperse) import qualified Data.ByteString as B import Data.ByteString.Internal (c2w, w2c) import qualified Data.Binary.Put as P import qualified Data.Binary.Strict.Get as G import qualified Data.Binary.Strict.BitGet as BG import Network.DNS.Types foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32 -- | Types of DNS queries. RFC 1035, 4.1.1. data QueryType = QUERY | IQUERY | SERVERSTATUS deriving (Show, Eq, Enum, Bounded) -- | Parse an enum from a Word8 in a monad and fail if the value is out of range. -- It's assumed that the enum is defined at every value between the min and max -- bound parseEnum :: forall m a. (Enum a, Bounded a, Monad m) => Word8 -> m a parseEnum x' = r where r = if x < low || x > high then fail "Enum out of bounds" else return $ toEnum x low = fromEnum (minBound :: a) high = fromEnum (maxBound :: a) x = fromIntegral x' -- | A DNS protocol header. RFC 1035, 4.1.1. data Header = Header { headId :: Word16 , headIsResponse :: Bool , headOpCode :: QueryType , headIsAuthoritative :: Bool , headIsTruncated :: Bool , headRecursionDesired :: Bool , headRecursionAvailible :: Bool , headResponseCode :: ResponseCode , headQuestionCount :: Int , headAnswerCount :: Int , headNSCount :: Int , headAdditionalCount :: Int } deriving (Show, Eq) parseHeader :: G.Get Header parseHeader = do id <- G.getWord16be flags <- G.getByteString 2 qdcount <- G.getWord16be >>= return . fromIntegral ancount <- G.getWord16be >>= return . fromIntegral nscount <- G.getWord16be >>= return . fromIntegral arcount <- G.getWord16be >>= return . fromIntegral let r = BG.runBitGet flags (do isquery <- BG.getBit opcode <- BG.getAsWord8 4 >>= parseEnum aa <- BG.getBit tc <- BG.getBit rd <- BG.getBit ra <- BG.getBit BG.getAsWord8 3 rcode <- BG.getAsWord8 4 >>= parseEnum return $ Header id isquery opcode aa tc rd ra rcode qdcount ancount nscount arcount) case r of Left error -> fail error Right x -> return x serialiseHeader :: Header -> P.Put serialiseHeader header = do P.putWord16be $ headId header let flags1 = (v (headIsResponse header) `shiftL` 7) .|. (fromEnum (headOpCode header) `shiftL` 3) .|. (v (headIsAuthoritative header) `shiftL` 2) .|. (v (headIsTruncated header) `shiftL` 1) .|. (v (headRecursionDesired header)) flags2 = (v (headRecursionAvailible header) `shiftL` 7) .|. (fromEnum (headResponseCode header)) v True = 1 v False = 0 P.putWord8 $ fromIntegral flags1 P.putWord8 $ fromIntegral flags2 P.putWord16be $ fromIntegral $ headQuestionCount header P.putWord16be $ fromIntegral $ headAnswerCount header P.putWord16be $ fromIntegral $ headNSCount header P.putWord16be $ fromIntegral $ headAdditionalCount header -- | Break a DNS name (e.g. www.google.com) into a list of labels splitDNSName :: String -> [String] splitDNSName = filter (not . null) . split '.' where split c xs = head : tail where (head, rest) = span (/= c) xs tail = case rest of [] -> [] (_:xs) -> split c xs -- | Convert a split name into the length-prefixed DNS wire format. -- FIXME: should work with the IDNA system. Returns Nothing if the -- name couldn't be serialised. -- FIXME: catch invalid charactors and > 255 parts serialiseDNSName :: [String] -> Maybe B.ByteString serialiseDNSName x | length x > 255 = fail "" | otherwise = mapM f x >>= return . (flip B.snoc) 0 . B.concat where f x | length x > 63 = fail "" | otherwise = return $ B.cons lengthByte s where lengthByte = fromIntegral $ length x s = B.pack $ map c2w x -- | Convert a list of labels to a normal string by interspersing periods fromDNSName :: [String] -> String fromDNSName = concat . intersperse "." parseDNSName :: B.ByteString -> G.Get [String] parseDNSName packet = do let getLabel 16 = fail "Pointer loop in DNS name" getLabel depth = do b <- G.getWord8 -- if it's a pointer we need to decode it if b .&. 0xc0 == 0xc0 then do b2 <- G.getWord8 let offset = ((fromIntegral $ b .&. 0x3f) `shiftL` 8) .|. fromIntegral b2 if offset >= B.length packet then fail "Invalid DNS label pointer" else case G.runGet (getLabel (depth + 1)) $ B.drop offset packet of (Left error, _) -> fail error (Right l, _) -> return l else if b == 0 then return [] else do l <- G.getByteString $ fromIntegral b rest <- getLabel depth return $ (map w2c $ B.unpack l) : rest getLabel (0 :: Int) serialiseQuestion :: B.ByteString -- ^ the encoded name (see @serialiseDNSName@) -> DNSType -- ^ the type of the question -> P.Put serialiseQuestion s ty = do P.putByteString s P.putWord16be $ fromIntegral $ fromEnum ty P.putWord16be 1 parseQuestion :: B.ByteString -> G.Get (String, DNSType) parseQuestion packet = do name <- parseDNSName packet ty <- parseDNSType G.getWord16be return (fromDNSName name, ty) parseDNSType :: G.Get DNSType parseDNSType = do ty <- G.getWord16be >>= return . toEnum . fromIntegral case ty of UnknownDNSType -> fail "Unknown DNS type in question" _ -> return ty deserialiseQuestion :: B.ByteString -> G.Get ([String], DNSType) deserialiseQuestion packet = do name <- parseDNSName packet ty <- parseDNSType G.getWord16be return (name, ty) parseGenericRR :: B.ByteString -> G.Get ([String], DNSType, Word32, B.ByteString) parseGenericRR packet = do name <- parseDNSName packet ty <- parseDNSType clas <- G.getWord16be when (clas /= 1) $ fail "Bad class in RR" ttl <- G.getWord32be rlen <- G.getWord16be bytes <- G.getByteString $ fromIntegral rlen return (name, ty, ttl, bytes) type Entry = ([String], Word32, RR) parseRR :: B.ByteString -> G.Get Entry parseRR packet = do (name, ty, ttl, bytes) <- parseGenericRR packet let parseMany :: G.Get a -> G.Get [a] parseMany parser = do emptyp <- G.isEmpty if emptyp then return [] else do v <- parser rest <- parseMany parser return $ v : rest parseIP = G.getWord32be >>= return . htonl parseA = parseMany parseIP parseAAAA = parseMany $ do a <- parseIP b <- parseIP c <- parseIP d <- parseIP return (a, b, c, d) parseName = parseDNSName packet parseMX = parseMany $ do pref <- G.getWord16be name <- parseDNSName packet return (fromIntegral pref, name) parseSOA = do name <- parseDNSName packet rname <- parseDNSName packet serial <- G.getWord32be refresh <- G.getWord32be retry <- G.getWord32be expire <- G.getWord32be minimum <- G.getWord32be return $ RRSOA name rname serial refresh retry expire minimum parseTXT = do length <- G.getWord8 G.getByteString (fromIntegral length) >>= return let parse = case ty of A -> parseA >>= return . RRA NS -> parseName >>= return . RRNS CNAME -> parseName >>= return . RRCNAME SOA -> parseSOA PTR -> parseName >>= return . RRPTR MX -> parseMX >>= return . RRMX TXT -> parseTXT >>= return . RRTXT AAAA -> parseAAAA >>= return . RRAAAA let (err, _) = G.runGet parse bytes case err of Left error -> fail error Right rr -> return (name, ttl, rr) data Packet = Packet Header [(String, DNSType)] [Entry] [Entry] [Entry] deriving (Show) parsePacket :: B.ByteString -> Either String Packet parsePacket input = fst $ G.runGet (do header <- parseHeader a <- sequence $ replicate (headQuestionCount header) $ parseQuestion input b <- sequence $ replicate (headAnswerCount header) $ parseRR input c <- sequence $ replicate (headNSCount header) $ parseRR input d <- sequence $ replicate (headAdditionalCount header) $ parseRR input return $ Packet header a b c d) input