module Network.DNS
( HostName
, aHostName
, hostName
, hostNameLabels
, arpaHostName
, HostAddr(..)
, Host4Addr
, Host6Addr
, aHostAddr
, aHostAddrOf
, aHost4Addr
, aHost6Addr
, aHostAddrIP
, DnsId
, DnsType(..)
, dnsTypeCode
, DnsData(..)
, DnsRecord(..)
, DnsQType(..)
, dnsQTypeCode
, DnsQuestion(..)
, DnsReq(..)
, DnsError(..)
, DnsResp(..)
) where
import Data.Typeable (Typeable)
#if !MIN_VERSION_base(4,7,0)
import Data.Typeable (Typeable1)
#endif
import Data.Proxy (Proxy(..))
import Data.Foldable (forM_)
import Data.Hashable
import Data.Word
import Data.Bits
import Data.Char (chr, ord)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Builder as BB
import qualified Data.Binary.Get as B
import Data.Serializer (Serializer, Serializable, SizedSerializable)
import qualified Data.Serializer as S
import Data.Deserializer (Deserializer, Deserializable)
import qualified Data.Deserializer as D
import Text.Parser.Combinators as P
import Text.Parser.Char as P
import Text.Printer ((<>))
import qualified Text.Printer as T
import Data.Textual (Printable, toAscii, toUtf8, Textual)
import qualified Data.Textual as T
import qualified Text.Ascii as A
import Text.Printf
import qualified Text.Read as TR
import Network.IP.Addr
import Control.Applicative ((<$>), Applicative(..), (<|>))
import Control.Monad (void, unless, ap, foldM)
-- | A DNS host name in dotted form: its labels joined with @.@.
newtype HostName = HN {
    -- | The underlying bytes of the name.
    hostName ∷ ByteString
  }
  deriving (Typeable, Eq, Ord, Hashable)

-- | A 'Proxy' for 'HostName'.
aHostName ∷ Proxy HostName
aHostName = Proxy
-- | Renders as @fromJust (fromString "name")@, the form accepted by the
-- 'Read' instance.
instance Show HostName where
  showsPrec d (HN bytes) =
      showParen (d > 10) $ showString "fromJust " . showParen True inner
    where
      inner = showString "fromString " . showsPrec 10 (BS8.unpack bytes)
-- | Parses @fromJust (fromString "name")@ (the 'Show' form). The string is
-- validated with 'T.fromString'; an invalid name makes the parse fail.
instance Read HostName where
  readPrec = TR.parens $ TR.prec 10 $ do
      TR.Ident "fromJust" ← TR.lexP
      TR.step $ TR.parens $ TR.prec 10 fromStringCall
    where
      -- The inner @fromString "..."@ application.
      fromStringCall = do
        TR.Ident "fromString" ← TR.lexP
        TR.String str ← TR.lexP
        maybe TR.pfail return (T.fromString str)
-- | Prints the name's bytes as ASCII text.
instance Printable HostName where
  print = T.ascii . hostName
-- | Parses a host name in dotted form.
--
-- Rules enforced here:
--
-- * every label starts with an ASCII letter;
-- * a label contains only letters, digits and dashes;
-- * a label never ends with a dash;
-- * a label is at most 63 characters long;
-- * the whole name is at most 255 characters long.
--
-- Parsing stops at the first character that cannot continue the name.
instance Textual HostName where
  textual = go [] (0 ∷ Int) False [] (0 ∷ Int) <?> "host name"
    where
      alphaNumOrDashOrDot c = A.isAlphaNum c || c == '-' || c == '.'
      -- go completed totalLen lastWasDash current currentLen
      --   completed: finished labels, each ending in its dot, most recent
      --              first
      --   current:   bytes of the label being read, in reverse order
      go !ls !ncs _ _ 0 =
        -- Start of a label: a letter is mandatory.
        optional (satisfy A.isAlpha) >>= \case
          Just c → if ncs == 255
                   then unexpected "Host name is too long"
                   else go ls (ncs + 1) False [A.ascii c] 1
          Nothing → unexpected "A letter expected"
      go !ls !ncs !dash !lcs !nlcs =
        optional (satisfy alphaNumOrDashOrDot) >>= \case
          Just '.' → if dash
                     then unexpected "Label ends with a dash"
                     else if ncs == 255
                          then unexpected "Host name is too long"
                          else go (reverse (A.ascii '.' : lcs) : ls)
                                  (ncs + 1) False [] 0
          Just c → if nlcs == 63
                   then unexpected "Label is too long"
                   else if ncs == 255
                        then unexpected "Host name is too long"
                        else go ls (ncs + 1) (c == '-')
                                (A.ascii c : lcs) (nlcs + 1)
          -- End of the name. The final label must not end with a dash
          -- either; without this check "foo-" was accepted although
          -- "foo-.bar" was rejected.
          Nothing → if dash
                    then unexpected "Label ends with a dash"
                    else return $ HN $ BS.pack $ concat
                                $ reverse $ reverse lcs : ls
-- | Prints @name:port@.
instance Printable (InetAddr HostName) where
  print (InetAddr host port) =
    T.print host <> T.char7 ':' <> T.print port

-- | Parses @name:port@.
instance Textual (InetAddr HostName) where
  textual = do
    host ← T.textual
    _ ← P.char ':'
    port ← T.textual
    return $ InetAddr host port
-- | Splits a host name into its labels at every @.@.
hostNameLabels ∷ HostName → [ByteString]
hostNameLabels (HN name) = BS.split (A.ascii '.') name
-- | The reverse-lookup host name of an address.
--
-- * IPv4: the octets in reverse order, followed by @in-addr.arpa@.
-- * IPv6: the hexadecimal nibbles of the address in reverse order, each
--   followed by a dot, then @ip6.arpa@.
arpaHostName ∷ IP → HostName
arpaHostName (IPv4 a) =
    HN $ BS8.pack $ printf "%i.%i.%i.%i.in-addr.arpa" o4 o3 o2 o1
  where (o1, o2, o3, o4) = ip4ToOctets a
arpaHostName (IPv6 a) =
    HN $ BS8.pack $ concatMap nibbles (reverse $ ip6ToWordList a)
                    ++ "ip6.arpa"
  where
    -- The four nibbles of one 16-bit group, least significant first,
    -- each followed by a dot.
    nibbles w = concat [ [hexDigit ((w `shiftR` s) .&. 0xF), '.']
                       | s ← [0, 4, 8, 12] ]
    -- Lower-case hexadecimal digit for a value in [0, 15]. The original
    -- code had `fromIntegral n 10` here (a type error) instead of n - 10.
    hexDigit n
      | n < 10    = chr $ ord '0' + fromIntegral n
      | otherwise = chr $ ord 'a' + fromIntegral (n - 10)
-- | A minimal writer monad. The accumulated output is kept as a function
-- @r ↦ output <> r@; applying it to 'mempty' yields everything written so
-- far (this is how 'compress' extracts the result).
newtype Writer s α = Writer { runWriter ∷ (s → s) → (s → s, α) }

instance Functor (Writer s) where
  fmap f m = Writer $ \append →
    let (append', x) = runWriter m append
    in (append', f x)

-- 'pure' is defined directly and 'return' delegates to it (the canonical
-- form); the previous @pure = return@ / custom 'return' pair triggers
-- -Wnoncanonical-monad-instances.
instance Applicative (Writer s) where
  pure x = Writer $ \append → (append, x)
  (<*>) = ap

instance Monad (Writer s) where
  return = pure
  m >>= f = Writer $ \append →
    let (append', x) = runWriter m append
    in runWriter (f x) append'
-- | A state monad transformer that carries two things:
--
-- * the name-compression table ('Map' @k v@);
-- * the current message offset, which becomes 'Nothing' once it can no
--   longer be tracked (see 'addToOffset').
newtype StateT k v μ α =
  StateT { runStateT ∷ Map k v → Maybe Word16 → μ (Map k v, Maybe Word16, α) }
-- | Compression: the table maps a label suffix to the offset where it was
-- written; output is collected by a 'Writer'.
type CompT s α = StateT [ByteString] Word16 (Writer s) α
-- | Decompression: the table maps an offset to the name suffix decoded there.
type DecompT μ α = StateT Word16 HostName μ α
-- | Run a compression computation whose first byte lands at the given
-- message offset, starting from an empty table, and return its output.
compress ∷ Serializer s ⇒ Word16 → CompT s () → s
compress start m = output mempty
  where
    run = runStateT m Map.empty (Just start)
    (output, _) = runWriter run id

-- | Run a decompression computation whose first byte is at the given
-- message offset, starting from an empty table, and return its result.
decompress ∷ Monad μ ⇒ Word16 → DecompT μ α → μ α
decompress start m =
  runStateT m Map.empty (Just start) >>= \(_, _, result) → return result
instance Monad μ ⇒ Functor (StateT k v μ) where
  fmap f m = StateT $ \ptrs offset → do
    (ptrs', offset', x) ← runStateT m ptrs offset
    return (ptrs', offset', f x)

instance Monad μ ⇒ Applicative (StateT k v μ) where
  pure x = lift (return x)
  (<*>) = ap

-- 'return' delegates to 'pure' (canonical form). 'fail' stopped being a
-- 'Monad' method in base-4.13, so defining it there is only allowed on
-- older base; it is guarded with CPP to keep the module compiling on both.
instance Monad μ ⇒ Monad (StateT k v μ) where
  return = pure
  m >>= f = StateT $ \ptrs offset → do
    (ptrs', offset', x) ← runStateT m ptrs offset
    runStateT (f x) ptrs' offset'
#if !MIN_VERSION_base(4,13,0)
  fail msg = lift $ fail msg
#endif

-- | Run an action of the underlying monad, leaving table and offset as is.
lift ∷ Monad μ ⇒ μ α → StateT k v μ α
lift m = StateT $ \ptrs offset → do
  x ← m
  return (ptrs, offset, x)
-- | Append serialized bytes to the output. The tracked offset is not
-- touched; callers advance it separately with 'incOffset'.
write ∷ Serializer s ⇒ s → CompT s ()
write chunk = lift (Writer extend)
  where extend acc = (\rest → acc chunk <> rest, ())
-- | The current message offset, or 'Nothing' if it is no longer tracked.
getOffset ∷ Monad μ ⇒ StateT k v μ (Maybe Word16)
getOffset = StateT peek
  where peek table off = return (table, off, off)
-- | Advance a tracked offset by @n@ bytes.
--
-- The result stays 'Just' only while the sum neither wraps around 'Word16'
-- nor exceeds 0x3FFF, the largest offset a 14-bit compression pointer can
-- hold. Once 'Nothing', the offset stays 'Nothing'.
addToOffset ∷ Word16 → Maybe Word16 → Maybe Word16
addToOffset n mi = do
  i ← mi
  let i' = i + n
  if i' >= i && i' <= 0x3FFF then Just i' else Nothing
-- | Advance the tracked message offset by @n@ bytes (see 'addToOffset').
incOffset ∷ Monad μ ⇒ Word16 → StateT k v μ ()
incOffset n = StateT advance
  where advance table off = return (table, addToOffset n off, ())
-- | The whole (de)compression table.
getEntries ∷ Monad μ ⇒ StateT k v μ (Map k v)
getEntries = StateT $ \table off → return (table, off, table)

-- | Look up a single entry of the table.
getEntry ∷ (Ord k, Monad μ) ⇒ k → StateT k v μ (Maybe v)
getEntry key =
  StateT $ \table off → return (table, off, Map.lookup key table)

-- | Insert an entry into the table, replacing any previous value.
putEntry ∷ (Ord k, Monad μ) ⇒ k → v → StateT k v μ ()
putEntry key value =
  StateT $ \table off → return (Map.insert key value table, off, ())
-- | Render a compression sub-computation into its own lazy 'BSL.ByteString'
-- (so its length can be measured before it is written). The table and
-- offset are shared with the enclosing computation, and the sub-computation's
-- updates to them are kept.
evalComp ∷ Serializer s ⇒ CompT BB.Builder () → CompT s BSL.ByteString
evalComp m = StateT $ \table off →
  let (output, (table', off', _)) = runWriter (runStateT m table off) id
  in return (table', off', BB.toLazyByteString (output mempty))
-- | Run a decompression sub-computation over exactly the next @len@ input
-- bytes.
--
-- * The table entries it adds are kept.
-- * Afterwards the offset is advanced by @len@, however much the
--   sub-computation actually consumed.
-- * Bytes it leaves unconsumed are silently ignored.
-- * A failure of the sub-computation is re-raised with 'unexpected'.
evalDecomp ∷ Deserializer μ
           ⇒ Word16
           → DecompT D.BinaryDeserializer α
           → DecompT μ α
evalDecomp len m = StateT $ \table off → do
  chunk ← BSL.fromStrict <$> D.take (fromIntegral len)
  let parser = D.binaryDeserializer (runStateT m table off)
  either (\(_, _, err) → unexpected err)
         (\(_, _, (table', _, x)) → return (table', addToOffset len off, x))
         (B.runGetOrFail parser chunk)
-- | Encode a host name with name compression.
--
-- The name is tried from its full form down to ever shorter label
-- suffixes. The first suffix already in the table is replaced by a 2-byte
-- pointer to its offset; if none is found, a zero byte terminates the name.
-- Each label written in full records its suffix in the table, provided the
-- current offset is still tracked.
serializeHostName ∷ Serializer s ⇒ HostName → CompT s ()
serializeHostName = emit . hostNameLabels
  where
    emit [] = do
      write $ S.word8 0
      incOffset 1
    emit suffix@(label : rest) = getEntry suffix >>= \case
      Just ptr → do
        write $ S.word16B $ 0xC000 .|. ptr
        incOffset 2
      Nothing → do
        let len = BS.length label
        here ← getOffset
        write $ S.word8 (fromIntegral len) <> S.byteString label
        incOffset $ 1 + fromIntegral len
        forM_ here $ putEntry suffix
        emit rest
-- | Fail the deserializer with the given message unless the condition holds.
guard ∷ Deserializer μ ⇒ String → Bool → μ ()
guard msg ok
  | ok        = return ()
  | otherwise = unexpected msg
-- | Decode a possibly compressed host name.
--
-- Labels are read one at a time until one of two things appears:
--
-- * a zero length byte, which ends the name;
-- * a compression pointer (both top bits of the length byte set), whose
--   target must already be in the table.
--
-- Directly read labels are lower-cased. For every label whose start offset
-- is tracked, the name suffix starting there is recorded in the table so
-- that later pointers can refer to it. A name consisting only of the zero
-- terminator is rejected.
deserializeHostName ∷ Deserializer μ ⇒ DecompT μ HostName
deserializeHostName = go []
  where
    -- Prepend one label to an already decoded suffix and, if the label's
    -- start offset is tracked, record the resulting name at that offset.
    folder suffix (label, offset) = do
      forM_ offset $ \i → putEntry i (HN suffix')
      return suffix'
      where suffix' = BS.append label $ BS.cons (A.ascii '.') suffix
    -- labels: labels read so far paired with their start offsets, most
    -- recently read first.
    go labels = do
      offset ← getOffset
      w ← lift D.word8
      incOffset 1
      if w .&. 0xC0 == 0xC0
        then do
          -- Compression pointer: 14-bit offset from the low 6 bits of w and
          -- the following byte.
          w' ← lift D.word8
          incOffset 1
          let ptr = fromIntegral (w .&. 0x3F) `shiftL` 8 .|. fromIntegral w'
          entry ← getEntry ptr
          case entry of
            Nothing → do
              entries ← getEntries
              lift $ unexpected $ "Invalid pointer " ++ show ptr
                               ++ ": pointer map is "
                               ++ show (Map.elems entries)
            Just (HN suffix1) → HN <$> foldM folder suffix1 labels
        else
          if w == 0
            then do
              -- End of name: the guard makes the partial pattern below safe.
              lift $ guard "Hostname with zero labels" $ not $ null labels
              let (lastLabel, lastOffset) : labels' = labels
              forM_ lastOffset $ \i → putEntry i (HN lastLabel)
              HN <$> foldM folder lastLabel labels'
            else do
              -- Ordinary label of w bytes; lengths 64..191 (top bits 01 or
              -- 10) are rejected here.
              lift $ guard "Label is too long" $ w <= 63
              label ← lift $ D.take $ fromIntegral w
              incOffset $ fromIntegral w
              go ((BS.map A.toLower8 label, offset) : labels)
-- | A host identified either by name or by an address of type @a@.
data HostAddr a = HostName !HostName
                | HostAddr !a
                deriving (Typeable, Show, Read, Eq, Ord)

-- | A host given by name or by 'IP4' address.
type Host4Addr = HostAddr IP4
-- | A host given by name or by 'IP6' address.
type Host6Addr = HostAddr IP6

-- | A 'Proxy' for the 'HostAddr' type constructor.
aHostAddr ∷ Proxy HostAddr
aHostAddr = Proxy

-- | A 'Proxy' for 'HostAddr' over the address type of the given proxy.
aHostAddrOf ∷ Proxy a → Proxy (HostAddr a)
aHostAddrOf _ = Proxy

-- | A 'Proxy' for 'Host4Addr'.
aHost4Addr ∷ Proxy Host4Addr
aHost4Addr = Proxy

-- | A 'Proxy' for 'Host6Addr'.
aHost6Addr ∷ Proxy Host6Addr
aHost6Addr = Proxy

-- | A 'Proxy' for 'HostAddr' over 'IP'.
aHostAddrIP ∷ Proxy (HostAddr IP)
aHostAddrIP = Proxy
-- | Prints the name or the address, whichever the value holds.
instance Printable a ⇒ Printable (HostAddr a) where
  print = \case
    HostName name → T.print name
    HostAddr addr → T.print addr

-- | Tries a host name first (backtracking on failure), then an address.
instance Textual a ⇒ Textual (HostAddr a) where
  textual = P.try (fmap HostName T.textual) <|> fmap HostAddr T.textual
-- | Prints @name:port@ or the address-with-port form of @'InetAddr' a@.
instance Printable (InetAddr a) ⇒ Printable (InetAddr (HostAddr a)) where
  print (InetAddr host port) = case host of
    HostName name → T.print (InetAddr name port)
    HostAddr addr → T.print (InetAddr addr port)
-- | Parses @name:port@ (backtracking if that fails) or, otherwise, an
-- address with port using the @'InetAddr' a@ parser.
instance Textual (InetAddr a) ⇒ Textual (InetAddr (HostAddr a)) where
  textual = P.try (InetAddr <$> (HostName <$> T.textual)
                            <*> (P.char ':' *> T.textual))
        <|> (fromAddr <$> T.textual)
    where
      -- Re-wrap a value produced by the @'InetAddr' a@ instance (the
      -- instance's own context). Using a bare 'T.textual' here would select
      -- this very instance and loop forever whenever the name branch fails,
      -- e.g. on "1.2.3.4:80".
      fromAddr (InetAddr addr port) = InetAddr (HostAddr addr) port
-- | The 16-bit message identifier written at the start of every header.
type DnsId = Word16

-- | Resource record types supported by this module, indexed by the Haskell
-- type of their record data.
data DnsType α where
  AddrDnsType  ∷ DnsType IP4
  Addr6DnsType ∷ DnsType IP6
  NsDnsType    ∷ DnsType HostName
  CNameDnsType ∷ DnsType HostName
  PtrDnsType   ∷ DnsType HostName
  -- Encoded as a 16-bit value followed by a host name.
  MxDnsType    ∷ DnsType (Word16, HostName)
#if MIN_VERSION_base(4,7,0)
deriving instance Typeable DnsType
#else
deriving instance Typeable1 DnsType
#endif
deriving instance Eq (DnsType α)
-- | Shows the constructor name.
instance Show (DnsType α) where
  showsPrec _ t = showString (name t)
    where
      name ∷ DnsType β → String
      name AddrDnsType  = "AddrDnsType"
      name Addr6DnsType = "Addr6DnsType"
      name NsDnsType    = "NsDnsType"
      name CNameDnsType = "CNameDnsType"
      name PtrDnsType   = "PtrDnsType"
      name MxDnsType    = "MxDnsType"
-- | The numeric type code written on the wire for a record type.
dnsTypeCode ∷ DnsType α → Word16
dnsTypeCode t = case t of
  AddrDnsType  → 1
  NsDnsType    → 2
  CNameDnsType → 5
  PtrDnsType   → 12
  MxDnsType    → 15
  Addr6DnsType → 28
-- | Record data paired with its type tag; the tag determines the type of
-- the data.
data DnsData = ∀ α . DnsData { dnsType ∷ !(DnsType α)
                             , dnsData ∷ α
                             }
             deriving Typeable
-- | Shows the value in record syntax. Matching on the 'DnsType' tag is what
-- brings the payload's 'Show' instance into scope.
instance Show DnsData where
  showsPrec p (DnsData {..}) =
      showParen (p > 10) $ foldr (.) id
        [ showString "DnsData {dnsType = "
        , showsPrec (p + 1) dnsType
        , showString ", dnsData = "
        , showPayload
        , showString "}"
        ]
    where
      p' = 10 ∷ Int
      showPayload ∷ ShowS
      showPayload = case dnsType of
        AddrDnsType  → showsPrec p' dnsData
        Addr6DnsType → showsPrec p' dnsData
        NsDnsType    → showsPrec p' dnsData
        CNameDnsType → showsPrec p' dnsData
        PtrDnsType   → showsPrec p' dnsData
        MxDnsType    → showsPrec p' dnsData
-- | A resource record. The class is not stored: it is written as 1 and
-- ignored when read.
data DnsRecord = DnsRecord {
    dnsRecOwner ∷ !HostName
    -- ^ Name the record belongs to.
  , dnsRecTtl ∷ !Word32
    -- ^ Time to live.
  , dnsRecData ∷ !DnsData
    -- ^ Typed record data.
  }
  deriving (Typeable, Show)
-- | Encode a resource record in this order:
--
-- 1. owner name;
-- 2. type, class 1 and TTL;
-- 3. length-prefixed record data.
--
-- The data is rendered separately first so its length can be written ahead
-- of it. Names inside it still take part in compression, because the table
-- and offset are shared with 'evalComp'.
serializeDnsRecord ∷ Serializer s ⇒ DnsRecord → CompT s ()
serializeDnsRecord (DnsRecord {..}) = case dnsRecData of
  DnsData tp dt → do
    serializeHostName dnsRecOwner
    write $ S.word16B (dnsTypeCode tp) <> S.word16B 1 <> S.word32B dnsRecTtl
    -- type (2) + class (2) + TTL (4) + data length (2) bytes
    incOffset 10
    rdata ← evalComp $ case tp of
      AddrDnsType  → write (S.put dt) >> incOffset 4
      Addr6DnsType → write (S.put dt) >> incOffset 16
      NsDnsType    → serializeHostName dt
      CNameDnsType → serializeHostName dt
      PtrDnsType   → serializeHostName dt
      MxDnsType    → do
        write $ S.word16B $ fst dt
        incOffset 2
        serializeHostName $ snd dt
    let rdLength = fromIntegral (BSL.length rdata) ∷ Word16
    write $ S.word16B rdLength
         <> S.lazyByteString (BSL.take (fromIntegral rdLength) rdata)
-- | Decode a resource record.
--
-- * The class field is read and ignored.
-- * The record data is decoded from exactly the number of bytes given by
--   its length field (see 'evalDecomp').
-- * Type codes other than 1, 2, 5, 12, 15 and 28 are rejected.
--
-- MX (15) was previously missing, so records produced by
-- 'serializeDnsRecord' for 'MxDnsType' could not be read back.
deserializeDnsRecord ∷ Deserializer μ ⇒ DecompT μ DnsRecord
deserializeDnsRecord = do
  owner ← deserializeHostName
  code ← lift D.word16B
  void $ lift D.word16B                -- class, not checked
  ttl ← lift D.word32B
  len ← lift D.word16B
  -- type (2) + class (2) + TTL (4) + data length (2) bytes
  incOffset 10
  dd ← evalDecomp len $ case code of
    1  → DnsData AddrDnsType <$> (incOffset 4 >> lift D.get)
    2  → DnsData NsDnsType <$> deserializeHostName
    5  → DnsData CNameDnsType <$> deserializeHostName
    12 → DnsData PtrDnsType <$> deserializeHostName
    15 → do
      -- Mirrors the MxDnsType branch of serializeDnsRecord: a 16-bit value
      -- followed by a (possibly compressed) host name.
      preference ← lift D.word16B
      incOffset 2
      exchange ← deserializeHostName
      return $ DnsData MxDnsType (preference, exchange)
    28 → DnsData Addr6DnsType <$> (incOffset 16 >> lift D.get)
    _  → lift $ unexpected $ "Unsupported type " ++ show code
  return $ DnsRecord owner ttl dd
-- | The type field of a question: either a specific record type or
-- 'AllDnsType' (code 255).
data DnsQType = ∀ α . StdDnsType (DnsType α)
              | AllDnsType
              deriving Typeable
-- | Shows the constructor applied to the record type, if any.
instance Show DnsQType where
  showsPrec _ AllDnsType = showString "AllDnsType"
  showsPrec p (StdDnsType t) =
    showParen (p > 10) (showString "StdDnsType " . showsPrec (p + 1) t)
-- | The numeric type code written on the wire for a question type.
dnsQTypeCode ∷ DnsQType → Word16
dnsQTypeCode qt = case qt of
  StdDnsType t → dnsTypeCode t
  AllDnsType   → 255

-- | Question types are equal exactly when their wire codes are.
instance Eq DnsQType where
  (==) t1 t2 = dnsQTypeCode t1 == dnsQTypeCode t2

-- | Question types are ordered by their wire codes.
instance Ord DnsQType where
  compare t1 t2 = compare (dnsQTypeCode t1) (dnsQTypeCode t2)

-- | Encoded as the 16-bit big-endian wire code.
instance Serializable DnsQType where
  put t = S.word16B (dnsQTypeCode t)

-- | Always two bytes.
instance SizedSerializable DnsQType where
  size = const 2
-- | Decode a question type from its 16-bit wire code.
--
-- Every code 'dnsQTypeCode' can produce is accepted, including MX (15),
-- which was previously missing so MX questions could be written but not
-- read. The error message now separates the code from the text.
instance Deserializable DnsQType where
  get = D.word16B >>= \case
    1   → return $ StdDnsType AddrDnsType
    2   → return $ StdDnsType NsDnsType
    5   → return $ StdDnsType CNameDnsType
    12  → return $ StdDnsType PtrDnsType
    15  → return $ StdDnsType MxDnsType
    28  → return $ StdDnsType Addr6DnsType
    255 → return AllDnsType
    t   → unexpected $ "Unsupported query type " ++ show t
-- | A question: the name asked about and the requested type. The class is
-- always 1 on the wire.
data DnsQuestion = DnsQuestion {
    dnsQName ∷ !HostName
    -- ^ Name asked about.
  , dnsQType ∷ !DnsQType
    -- ^ Requested record type.
  }
  deriving (Typeable, Show, Eq, Ord)
-- | Encode a question: name, query type, then class 1.
serializeDnsQuestion ∷ Serializer s ⇒ DnsQuestion → CompT s ()
serializeDnsQuestion question = do
  serializeHostName (dnsQName question)
  write $ S.put (dnsQType question) <> S.word16B 1
  -- type (2) + class (2) bytes
  incOffset 4
-- | Decode a question; any class other than 1 is rejected.
deserializeDnsQuestion ∷ Deserializer μ ⇒ DecompT μ DnsQuestion
deserializeDnsQuestion = do
  name ← deserializeHostName
  qtype ← lift D.get
  lift $ do
    qclass ← D.word16B
    guard "Unsupported class in a question" (qclass == 1)
  -- type (2) + class (2) bytes
  incOffset 4
  return (DnsQuestion name qtype)
-- | A DNS request: a standard query (opcode 0) with a single question, or
-- an inverse query (opcode 1) for an address.
data DnsReq
  = DnsReq {
      dnsReqId ∷ !DnsId
      -- ^ Message identifier.
    , dnsReqTruncd ∷ !Bool
      -- ^ Flag in bit 1 of the first flags byte (truncation).
    , dnsReqRec ∷ !Bool
      -- ^ Flag in bit 0 of the first flags byte (recursion).
    , dnsReqQuestion ∷ !DnsQuestion
      -- ^ The query's only question.
    }
  | DnsInvReq { dnsReqId ∷ !DnsId
              , dnsReqInv ∷ !IP
                -- ^ Address to look up; sent as the data of an answer record.
              }
  deriving (Typeable, Show)
-- | Owner name used for the answer record of an inverse query.
anyHostName ∷ HostName
anyHostName = HN (BS8.pack "any")
-- | Wire encoding of a request: the 12-byte header, then the compressed body.
-- Compression offsets start at 12, right after the header.
instance Serializable DnsReq where
  put (DnsReq {..})
    =  S.word16B dnsReqId
       -- Query bit clear, opcode 0; recursion = bit 0, truncation = bit 1.
       -- Each conditional is parenthesised: an unparenthesised @if@ extends
       -- its @else@ branch over the following @.|.@, which dropped the
       -- truncation bit whenever the recursion bit was set.
    <> S.word8 ((if dnsReqRec then 1 else 0) .|. (if dnsReqTruncd then 2 else 0))
    <> S.word8 0
    <> S.word16B 1   -- questions
    <> S.word16B 0   -- answers
    <> S.word16B 0   -- authorities
    <> S.word16B 0   -- extras
    <> compress 12 (serializeDnsQuestion dnsReqQuestion)
  put (DnsInvReq {..})
    =  S.word16B dnsReqId
    <> S.word8 8     -- query bit clear, opcode 1 (inverse query)
    <> S.word8 0
    <> S.word16B 0   -- questions
    <> S.word16B 1   -- answers
    <> S.word16B 0   -- authorities
    <> S.word16B 0   -- extras
    <> compress 12 (serializeDnsRecord record)
    where
      record = DnsRecord { dnsRecOwner = anyHostName
                         , dnsRecTtl   = 0
                         , dnsRecData  = case dnsReqInv of
                             IPv4 a → DnsData AddrDnsType a
                             IPv6 a → DnsData Addr6DnsType a }
-- | Decode a request. Two opcodes are accepted:
--
-- * opcode 0, a standard query: exactly one question and no other sections;
-- * opcode 1, an inverse query: exactly one A or AAAA answer record and no
--   other sections.
--
-- Any other opcode is rejected, and so is a message with the response bit
-- (bit 7 of the first flags byte) set.
instance Deserializable DnsReq where
  get = do
      reqId ← D.word16B
      flags ← D.word8
      void D.word8
      guard "Not a request" $ flags .&. 128 == 0
      let recursive = flags .&. 1 /= 0
          truncated = flags .&. 2 /= 0
          opcode    = flags `shiftR` 3 .&. 0xF
      case opcode of
        0 → do
          expectCounts [ ("No questions in query", 1)
                       , ("Answers in query", 0)
                       , ("Authorities in query", 0)
                       , ("Extras in query", 0) ]
          decompress 12 $ do
            question ← deserializeDnsQuestion
            return $ DnsReq { dnsReqId       = reqId
                            , dnsReqTruncd   = truncated
                            , dnsReqRec      = recursive
                            , dnsReqQuestion = question }
        1 → do
          expectCounts [ ("Questions in inverse query", 0)
                       , ("No answers in inverse query", 1)
                       , ("Authorities in inverse query", 0)
                       , ("Extras in inverse query", 0) ]
          DnsRecord {dnsRecData} ← decompress 12 deserializeDnsRecord
          case dnsRecData of
            DnsData AddrDnsType a →
              return $ DnsInvReq { dnsReqId = reqId, dnsReqInv = IPv4 a }
            DnsData Addr6DnsType a →
              return $ DnsInvReq { dnsReqId = reqId, dnsReqInv = IPv6 a }
            _ → unexpected "Invalid answer RR in inverse query"
        _ → unexpected $ "Invalid opcode " ++ show opcode ++ " in request"
    where
      -- Read the four section counts in order. Each must equal its paired
      -- value; otherwise decoding fails with the paired message.
      expectCounts pairs = mapM_ check pairs
        where check (msg, expected) = D.word16B >>= guard msg . (== expected)
-- | Error conditions a DNS response can report.
data DnsError = FormatDnsError
              | FailureDnsError
              | NoNameDnsError
              | NotImplDnsError
              | RefusedDnsError
              | NameExistsDnsError
              | RsExistsDnsError
              | NoRsDnsError
              | NotAuthDnsError
              | NotInZoneDnsError
              deriving (Typeable, Show, Read, Eq, Ord, Enum)

-- | The response code for an error. It is stored in the low four bits of
-- the second header flags byte.
dnsErrorCode ∷ DnsError → Word8
dnsErrorCode err = case err of
  FormatDnsError     → 1
  FailureDnsError    → 2
  NoNameDnsError     → 3
  NotImplDnsError    → 4
  RefusedDnsError    → 5
  NameExistsDnsError → 6
  RsExistsDnsError   → 7
  NoRsDnsError       → 8
  NotAuthDnsError    → 9
  NotInZoneDnsError  → 10
-- | A DNS response: either a successful reply or an error reply. A
-- successful reply carries one question plus the answer, authority and
-- extra record sections.
data DnsResp
  = DnsResp {
      dnsRespId ∷ !DnsId
      -- ^ Message identifier.
    , dnsRespTruncd ∷ !Bool
      -- ^ Flag in bit 1 of the first flags byte (truncation).
    , dnsRespAuthd ∷ !Bool
      -- ^ Flag in bit 2 of the first flags byte (authoritative answer).
    , dnsRespRec ∷ !Bool
      -- ^ Flag in bit 7 of the second flags byte (recursion).
    , dnsRespQuestion ∷ !DnsQuestion
      -- ^ The question being answered.
    , dnsRespAnswers ∷ [DnsRecord]
      -- ^ Answer section.
    , dnsRespAuths ∷ [DnsRecord]
      -- ^ Authority section.
    , dnsRespExtras ∷ [DnsRecord]
      -- ^ Extra (additional) section.
    }
  | DnsErrResp { dnsRespId ∷ !DnsId
               , dnsRespError ∷ !DnsError
               }
  deriving (Typeable, Show)
-- | Wire encoding of a response: the 12-byte header, then (for successful
-- replies) the compressed body. Compression offsets start at 12.
instance Serializable DnsResp where
  put (DnsResp {..})
    =  S.word16B dnsRespId
       -- Response bit (7) set, opcode 0; truncation = bit 1, authoritative
       -- = bit 2. Each conditional is parenthesised: an unparenthesised @if@
       -- extends its @else@ over the next @.|.@, which dropped the
       -- authoritative bit whenever the truncation bit was set.
    <> S.word8 (128 .|. (if dnsRespTruncd then 2 else 0)
                    .|. (if dnsRespAuthd then 4 else 0))
       -- Recursion bit (7); response code 0.
    <> S.word8 (if dnsRespRec then 128 else 0)
    <> S.word16B 1
    <> S.word16B (fromIntegral $ length dnsRespAnswers)
    <> S.word16B (fromIntegral $ length dnsRespAuths)
    <> S.word16B (fromIntegral $ length dnsRespExtras)
    <> compress 12 records
    where
      records = do
        serializeDnsQuestion dnsRespQuestion
        forM_ dnsRespAnswers serializeDnsRecord
        forM_ dnsRespAuths serializeDnsRecord
        forM_ dnsRespExtras serializeDnsRecord
  put (DnsErrResp {..})
    =  S.word16B dnsRespId
       -- Response bit (7) set, opcode 0. The Deserializable instance rejects
       -- any message without this bit as "Not a response"; writing 8 here
       -- would encode a non-response with opcode 1.
    <> S.word8 128
       -- Response code in the low four bits.
    <> S.word8 (dnsErrorCode dnsRespError)
    <> S.word16B 0
    <> S.word16B 0
    <> S.word16B 0
    <> S.word16B 0
-- | Decode a response.
--
-- * The response bit (7) of the first flags byte must be set.
-- * Response code 0: exactly one question is required, then the answer,
--   authority and extra records are decoded according to their counts.
-- * Response codes 1–10: an error reply. The four counts are skipped and
--   no sections are decoded.
-- * Any other response code is rejected.
instance Deserializable DnsResp where
  get = do
    respId ← D.word16B
    flags ← D.word8
    guard "Not a response" $ flags .&. 128 /= 0
    flags' ← D.word8
    let truncated     = flags .&. 2 /= 0
        authoritative = flags .&. 4 /= 0
        recursive     = flags' .&. 128 /= 0
        rcode         = flags' .&. 0xF
    if rcode == 0
      then do
        D.word16B >>= guard "No question in a response" . (== 1)
        anCount ← D.word16B
        nsCount ← D.word16B
        arCount ← D.word16B
        decompress 12 $ do
          question ← deserializeDnsQuestion
          answers ← sequence $ replicate (fromIntegral anCount) deserializeDnsRecord
          auths ← sequence $ replicate (fromIntegral nsCount) deserializeDnsRecord
          extras ← sequence $ replicate (fromIntegral arCount) deserializeDnsRecord
          return $ DnsResp { dnsRespId       = respId
                           , dnsRespTruncd   = truncated
                           , dnsRespAuthd    = authoritative
                           , dnsRespRec      = recursive
                           , dnsRespQuestion = question
                           , dnsRespAnswers  = answers
                           , dnsRespAuths    = auths
                           , dnsRespExtras   = extras }
      else do
        -- Skip the four section counts; the sections are not decoded.
        mapM_ (const D.word16B) [1 .. 4 ∷ Int]
        DnsErrResp respId <$>
          case [e | e ← [FormatDnsError ..], dnsErrorCode e == rcode] of
            e : _ → return e
            []    → unexpected $ "Unknown error code " ++ show rcode