{-# LANGUAGE OverloadedStrings #-}

module Network.DNS.IO (
    -- * Receiving DNS messages
    receive
  , receiveFrom
  , receiveVC
    -- * Sending pre-encoded messages
  , send
  , sendTo
  , sendVC
  , sendAll
    -- ** Encoding queries for transmission
  , encodeQuestion
  , encodeVC
    -- ** Creating query response messages
  , responseA
  , responseAAAA
  ) where

import qualified Control.Exception as E
import qualified Data.ByteString as B
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.IP (IPv4, IPv6)
import Time.System (timeCurrent)
import Time.Types (Elapsed(..), Seconds(..))
import Network.Socket (Socket, SockAddr)
import Network.Socket.ByteString (recv, recvFrom)
import qualified Network.Socket.ByteString as Socket
import System.IO.Error

import Network.DNS.Decode (decodeAt)
import Network.DNS.Encode (encode)
import Network.DNS.Imports
import Network.DNS.Types.Internal

----------------------------------------------------------------

-- | Receive one UDP datagram on the given 'Socket' and decode it as a
-- 'DNSMessage'.  The sender's address is not returned; use 'receiveFrom' when
-- it is needed.  At most 'maxUdpSize' bytes are read, so a longer datagram is
-- silently truncated.  That should not happen in practice, because the
-- advertised EDNS UDP buffer size is capped at the same value.
--
-- Throws a 'DNSError': 'NetworkFailure' wrapping the 'IOException' when the
-- receive fails, or the error reported by 'decodeAt' when decoding fails.
--
receive :: Socket -> IO DNSMessage
receive sock = do
    datagram <- recv sock maxLen `E.catch` (E.throwIO . NetworkFailure)
    -- 'decodeAt' is handed the current time (in seconds) alongside the bytes.
    Elapsed (Seconds now) <- timeCurrent
    either E.throwIO return (decodeAt now datagram)
  where
    maxLen :: Int
    maxLen = fromIntegral maxUdpSize

-- | Receive one UDP datagram on the given 'Socket', decode it as a
-- 'DNSMessage', and return it together with the sender's address.  At most
-- 'maxUdpSize' bytes are read, so a longer datagram is silently truncated.
-- That should not happen in practice, because the advertised EDNS UDP buffer
-- size is capped at the same value.
--
-- Throws a 'DNSError': 'NetworkFailure' wrapping the 'IOException' when the
-- receive fails, or the error reported by 'decodeAt' when decoding fails.
--
receiveFrom :: Socket -> IO (DNSMessage, SockAddr)
receiveFrom sock = do
    (datagram, peer) <- recvFrom sock maxLen `E.catch` (E.throwIO . NetworkFailure)
    -- 'decodeAt' is handed the current time (in seconds) alongside the bytes.
    Elapsed (Seconds now) <- timeCurrent
    either E.throwIO (\msg -> return (msg, peer)) (decodeAt now datagram)
  where
    maxLen :: Int
    maxLen = fromIntegral maxUdpSize

-- | Receive and decode a single 'DNSMessage' from a virtual circuit (TCP).
-- The message is read as a two-byte big-endian length prefix followed by
-- that many bytes of payload.  It is up to the caller to implement any
-- desired timeout.  A 'DNSError' is raised if I/O or message decoding fails.
--
receiveVC :: Socket -> IO DNSMessage
receiveVC sock = do
    prefix  <- recvDNS sock 2
    payload <- recvDNS sock (frameLength prefix)
    -- 'decodeAt' is handed the current time (in seconds) alongside the bytes.
    Elapsed (Seconds now) <- timeCurrent
    either E.throwIO return (decodeAt now payload)
  where
    -- Interpret the two prefix bytes as an unsigned big-endian length.
    frameLength :: ByteString -> Int
    frameLength bytes = case B.unpack bytes of
        [hi, lo] -> fromIntegral hi * 256 + fromIntegral lo
        _        -> 0  -- unreachable: recvDNS sock 2 yields exactly two bytes

-- | Read exactly @len@ bytes from the 'Socket', calling 'recv' as many times
-- as needed until the requested amount has arrived.
--
-- If a 'recv' returns no data before @len@ bytes have been collected (the
-- peer closed the connection), an EOF 'IOError' is raised.  Every
-- 'IOException' raised while reading, including that EOF error, is rethrown
-- as a 'NetworkFailure' 'DNSError'.  A non-positive @len@ yields an empty
-- string without reading from the socket.
--
recvDNS :: Socket -> Int -> IO ByteString
recvDNS sock len = readExactly `E.catch` \e -> E.throwIO $ NetworkFailure e
  where
    readExactly :: IO ByteString
    readExactly
      -- Never issue a zero-length 'recv': it would either be rejected as an
      -- invalid argument or look like EOF, misreporting an empty frame as a
      -- network failure.
      | len <= 0  = return B.empty
      | otherwise = go len []

    -- Accumulate the received chunks newest-first and concatenate them once
    -- at the end, instead of re-copying the whole buffer after every 'recv'.
    go :: Int -> [ByteString] -> IO ByteString
    go left chunks
      | left <= 0 = return $ B.concat (reverse chunks)
      | otherwise = do
          chunk <- recvCore left
          go (left - B.length chunk) (chunk : chunks)

    eofE :: IOError
    eofE = mkIOError eofErrorType "connection terminated" Nothing Nothing

    -- A single 'recv' of at most @n@ bytes; an empty result is treated as EOF.
    recvCore :: Int -> IO ByteString
    recvCore n = do
        bs <- recv sock n
        if B.null bs then E.throwIO eofE else return bs

----------------------------------------------------------------

-- | Send an encoded 'DNSMessage' as a single UDP datagram.  No length prefix
-- is added, because with UDP the message length is implicit in the datagram
-- size.  Over TCP use 'sendVC' instead: TCP has no message boundaries, so
-- each message must carry an explicit length.  The socket must already be
-- connected to the destination nameserver.
--
send :: Socket -> ByteString -> IO ()
send sock bs = do
    -- The byte count returned by 'Socket.send' is intentionally discarded.
    _ <- Socket.send sock bs
    return ()
{-# INLINE send #-}

-- | Send an encoded 'DNSMessage' as a single UDP datagram to the given
-- address.  No length prefix is added, because with UDP the message length
-- is implicit in the datagram size.  Over TCP use 'sendVC' instead: TCP has
-- no message boundaries, so each message must carry an explicit length.
--
sendTo :: Socket -> ByteString -> SockAddr -> IO ()
sendTo sock str addr = void $ Socket.sendTo sock str addr
{-# INLINE sendTo #-}

-- | Send one encoded 'DNSMessage' over TCP.  The buffer is first framed with
-- an explicit length prefix (via 'encodeVC') and then written in full with
-- 'sendAll'.
--
-- To send a batch of encoded messages back-to-back over one TCP connection
-- and then loop to collect the replies, frame each message with 'encodeVC',
-- concatenate the results, and pass the combined buffer to 'sendAll'.
--
sendVC :: Socket -> ByteString -> IO ()
sendVC sock msg = sendAll sock (encodeVC msg)
{-# INLINE sendVC #-}

-- | Send one or more encoded 'DNSMessage' buffers over TCP.  Each message
-- must already be framed with an explicit length prefix (for example via
-- 'encodeVC'), and the framed messages concatenated into a single buffer.
-- Do NOT use 'sendAll' with UDP.
--
sendAll :: Socket -> ByteString -> IO ()
sendAll sock bufs = Socket.sendAll sock bufs
{-# INLINE sendAll #-}

-- | Build the query 'DNSMessage' for a 'Question' (via 'makeQuery') and encode
-- it to wire format.  The message carries the given request ID.  The default
-- values of the RD, AD, CD and DO flag bits, as well as various EDNS
-- features, can be adjusted via the 'QueryControls' parameter.
--
-- The caller is responsible for generating the ID via a securely seeded
-- CSPRNG.
--
encodeQuestion :: Identifier     -- ^ Crypto random request id
               -> Question       -- ^ Query name and type
               -> QueryControls  -- ^ Query flag and EDNS overrides
               -> ByteString
encodeQuestion ident question = encode . makeQuery ident question

-- | Frame an encoded 'DNSMessage' buffer for transmission over a TCP virtual
-- circuit.  With TCP the buffer must start with an explicit length (with UDP
-- the length is implicit in the datagram size).  The length is written as a
-- two-byte unsigned big-endian integer.
--
-- NOTE: the prefix can only represent lengths up to 65535.  A longer buffer
-- gets a prefix holding its length modulo 65536, so callers must not pass
-- oversized messages.
--
encodeVC :: ByteString -> ByteString
encodeVC legacyQuery = lenPrefix <> legacyQuery
  where
    -- The length field is unsigned, so encode it as a 'Word16' rather than
    -- through a signed 16-bit encoder.
    lenPrefix :: ByteString
    lenPrefix = LBS.toStrict . BB.toLazyByteString . BB.word16BE . fromIntegral
              $ BS.length legacyQuery
{-# INLINE encodeVC #-}

----------------------------------------------------------------

-- | Compose a response carrying a single IPv4 RRset: one class-IN 'A' record
-- per address, all owned by the question's name.  If the query had an EDNS
-- pseudo-header, a suitable EDNS pseudo-header must be added to the response
-- message, or else a 'FormatErr' response must be sent.  The TTL defaults
-- to 300 seconds.  If some other TTL value is more appropriate, update it
-- (to the same value across all the RRs).
--
responseA :: Identifier -> Question -> [IPv4] -> DNSMessage
responseA ident q ips = makeResponse ident q answers
  where
    answers :: Answers
    answers = [ ResourceRecord (qname q) A classIN 300 (RD_A ip) | ip <- ips ]

-- | Compose a response carrying a single IPv6 RRset: one class-IN 'AAAA'
-- record per address, all owned by the question's name.  If the query had
-- an EDNS pseudo-header, a suitable EDNS pseudo-header must be added to the
-- response message, or else a 'FormatErr' response must be sent.  The TTL
-- defaults to 300 seconds.  If some other TTL value is more appropriate,
-- update it (to the same value across all the RRs).
--
responseAAAA :: Identifier -> Question -> [IPv6] -> DNSMessage
responseAAAA ident q ips = makeResponse ident q answers
  where
    answers :: Answers
    answers = [ ResourceRecord (qname q) AAAA classIN 300 (RD_AAAA ip) | ip <- ips ]