{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MultiWayIf #-}

module Hans.Message.Dns (
    DNSPacket(..)
  , DNSHeader(..)
  , OpCode(..)
  , RespCode(..)
  , Query(..)
  , QClass(..)
  , QType(..)
  , RR(..)
  , Type(..)
  , Class(..)
  , RData(..)
  , Name

  , parseDNSPacket,  getDNSPacket
  , renderDNSPacket, putDNSPacket
  ) where

import Hans.Address.IP4
import Hans.Utils (chunk)

import Control.Monad
import Data.Bits
import Data.Foldable ( traverse_, foldMap )
import Data.Int
import Data.Serialize ( Putter, runPut, putWord8, putWord16be, putWord32be
                      , putByteString )
import Data.Word
import MonadLib ( lift, StateT, runStateT, get, set )
import Numeric ( showHex )

import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict as Map
import qualified Data.Serialize.Get as C


-- DNS Packets -----------------------------------------------------------------

data DNSPacket = DNSPacket { dnsHeader            :: DNSHeader
                           , dnsQuestions         :: [Query]
                           , dnsAnswers           :: [RR]
                           , dnsAuthorityRecords  :: [RR]
                           , dnsAdditionalRecords :: [RR]
                           } deriving (Show)

data DNSHeader = DNSHeader { dnsId     :: !Word16
                           , dnsQuery  :: Bool
                           , dnsOpCode :: OpCode
                           , dnsAA     :: Bool
                           , dnsTC     :: Bool
                           , dnsRD     :: Bool
                           , dnsRA     :: Bool
                           , dnsRC     :: RespCode
                           } deriving (Show)

data OpCode = OpQuery
            | OpIQuery
            | OpStatus
            | OpReserved !Word16
              deriving (Show)

data RespCode = RespNoError
              | RespFormatError
              | RespServerFailure
              | RespNameError
              | RespNotImplemented
              | RespRefused
              | RespReserved !Word16
                deriving (Eq,Show)

type Name = [S.ByteString]

data Query = Query { qName  :: Name
                   , qType  :: QType
                   , qClass :: QClass
                   } deriving (Show)

data RR = RR { rrName  :: Name
             , rrClass :: Class
             , rrTTL   :: !Int32
             , rrRData :: RData
             } deriving (Show)

data QType = QType Type
           | AFXR
           | MAILB
           | MAILA
           | QTAny
             deriving (Show)

data Type = A
          | NS
          | MD
          | MF
          | CNAME
          | SOA
          | MB
          | MG
          | MR
          | NULL
          | PTR
          | HINFO
          | MINFO
          | MX
          | AAAA
            deriving (Show)

data QClass = QClass Class
            | QAnyClass
              deriving (Show)

data Class = IN | CS | CH | HS
             deriving (Show,Eq)

data RData = RDA IP4
           | RDNS Name
           | RDMD Name
           | RDMF Name
           | RDCNAME Name
           | RDSOA Name Name !Word32 !Int32 !Int32 !Int32 !Word32
           | RDMB Name
           | RDMG Name
           | RDMR Name
           | RDPTR Name
           | RDHINFO S.ByteString S.ByteString
           | RDMINFO Name Name
           | RDMX !Word16 Name
           | RDNULL S.ByteString
           | RDUnknown Type S.ByteString
             deriving (Show)


-- Cereal With Label Compression -----------------------------------------------

data RW = RW { rwOffset :: !Int
             , rwLabels :: Map.Map Int Name
             } deriving (Show)

type Get = StateT RW C.Get

{-# INLINE unGet #-}
unGet :: Get a -> C.Get a
unGet m =
  do (a,_) <- runStateT RW { rwOffset = 0, rwLabels = Map.empty } m
     return a

getOffset :: Get Int
getOffset  = rwOffset `fmap` get

addOffset :: Int -> Get ()
addOffset off =
  do rw <- get
     set $! rw { rwOffset = rwOffset rw + off }

lookupPtr :: Int -> Get Name
lookupPtr off =
  do rw <- get
     when (off >= rwOffset rw) (fail "Invalid offset in pointer")
     case Map.lookup off (rwLabels rw) of
       Just ls -> return ls
       Nothing -> fail $ "Unknown label for offset: " ++ showHex off "\n"
                      ++ show (rwLabels rw)

data Label = Label Int S.ByteString
           | Ptr Int Name
             deriving (Show)

labelsToName :: [Label] -> Name
labelsToName  = foldMap toName
  where
  toName (Label _ l) = [l]
  toName (Ptr _ n)   = n

addLabels :: [Label] -> Get ()
addLabels labels =
  do rw <- get
     set $! rw { rwLabels = Map.fromList newLabels `Map.union` rwLabels rw }
  where
  newLabels = go labels (labelsToName labels)

  go (Label off _ : rest) name@(_ : ns) = (off,name) : go rest ns
  go (Ptr off _   : _)    name          = [(off,name)]
  go _                    _             = []


{-# INLINE liftGet #-}
liftGet :: Int -> C.Get a -> Get a
liftGet n m = do addOffset n
                 lift m


{-# INLINE getWord8 #-}
getWord8 :: Get Word8
getWord8  = liftGet 1 C.getWord8

{-# INLINE getWord16be #-}
getWord16be :: Get Word16
getWord16be  = liftGet 2 C.getWord16be

{-# INLINE getWord32be #-}
getWord32be :: Get Word32
getWord32be  = liftGet 4 C.getWord32be

{-# INLINE getInt32be #-}
getInt32be :: Get Int32
getInt32be  = fromIntegral `fmap` liftGet 4 C.getWord32be

{-# INLINE getBytes #-}
getBytes :: Int -> Get S.ByteString
getBytes n = liftGet n (C.getBytes n)

isolate :: Int -> Get a -> Get a
isolate n body =
  do off      <- get
     (a,off') <- lift (C.isolate n (runStateT off body))
     set off'
     return a

label :: String -> Get a -> Get a
label str m =
  do off      <- get
     (a,off') <- lift (C.label str (runStateT off m))
     set off'
     return a


{-# INLINE putInt32be #-}
putInt32be :: Putter Int32
putInt32be i = putWord32be (fromIntegral i)

-- Parsing ---------------------------------------------------------------------

parseDNSPacket :: S.ByteString -> Either String DNSPacket
parseDNSPacket  = C.runGet getDNSPacket

getDNSPacket :: C.Get DNSPacket
getDNSPacket  = unGet $ label "DNSPacket" $
  do dnsHeader <- getDNSHeader
     qdCount   <- getWord16be
     anCount   <- getWord16be
     nsCount   <- getWord16be
     arCount   <- getWord16be

     let blockOf c l m = label l (replicateM (fromIntegral c) m)
     dnsQuestions         <- blockOf qdCount "Questions"          getQuery
     dnsAnswers           <- blockOf anCount "Answers"            getRR
     dnsAuthorityRecords  <- blockOf nsCount "Authority Records"  getRR
     dnsAdditionalRecords <- blockOf arCount "Additional Records" getRR

     return DNSPacket { .. }

--                                  1  1  1  1  1  1
--    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                      ID                       |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                    QDCOUNT                    |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                    ANCOUNT                    |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                    NSCOUNT                    |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                    ARCOUNT                    |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
getDNSHeader :: Get DNSHeader
getDNSHeader  = label "DNS Header" $
  do dnsId <- getWord16be
     flags <- getWord16be
     let dnsQuery  = not (flags `testBit` 15)
         dnsOpCode = parseOpCode (flags `shiftR` 11)
         dnsAA     = flags `testBit` 10
         dnsTC     = flags `testBit`  9
         dnsRD     = flags `testBit`  8
         dnsRA     = flags `testBit`  7
         dnsZ      = (flags `shiftR` 4) .&. 0x7
         dnsRC     = parseRespCode (flags .&. 0xf)

     unless (dnsZ == 0) (fail ("Z not zero"))

     return DNSHeader { .. }


parseOpCode :: Word16 -> OpCode
parseOpCode 0 = OpQuery
parseOpCode 1 = OpIQuery
parseOpCode 2 = OpStatus
parseOpCode c = OpReserved (c .&. 0xf)

parseRespCode :: Word16 -> RespCode
parseRespCode 0 = RespNoError
parseRespCode 1 = RespFormatError
parseRespCode 2 = RespServerFailure
parseRespCode 3 = RespNameError
parseRespCode 4 = RespNotImplemented
parseRespCode 5 = RespRefused
parseRespCode c = RespReserved (c .&. 0xf)

getQuery :: Get Query
getQuery  = label "Question" $
  do qName  <- getName
     qType  <- label "QTYPE"  getQType
     qClass <- label "QCLASS" getQClass
     return Query { .. }

--                                  1  1  1  1  1  1
--    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                                               |
--  /                                               /
--  /                      NAME                     /
--  |                                               |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                      TYPE                     |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                     CLASS                     |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                      TTL                      |
--  |                                               |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
--  |                   RDLENGTH                    |
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
--  /                     RDATA                     /
--  /                                               /
--  +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
getRR :: Get RR
getRR  = label "RR" $
  do rrName  <- getName
     ty      <- getType
     rrClass <- getClass
     rrTTL   <- getInt32be
     rrRData <- getRData ty
     return RR { .. }


getType :: Get Type
getType  =
  do qt <- getQType
     case qt of
       QType ty -> return ty
       _        -> fail ("Invalid TYPE: " ++ show qt)

getQType :: Get QType
getQType  =
  do tag <- getWord16be
     case tag of
       1   -> return (QType A)
       2   -> return (QType NS)
       3   -> return (QType MD)
       4   -> return (QType MF)
       5   -> return (QType CNAME)
       6   -> return (QType SOA)
       7   -> return (QType MB)
       8   -> return (QType MG)
       9   -> return (QType MR)
       10  -> return (QType NULL)
       12  -> return (QType PTR)
       13  -> return (QType HINFO)
       14  -> return (QType MINFO)
       15  -> return (QType MX)
       28  -> return (QType AAAA)
       252 -> return AFXR
       253 -> return MAILB
       254 -> return MAILA
       255 -> return QTAny
       _   -> fail ("Invalid TYPE: " ++ show tag)


getQClass :: Get QClass
getQClass  =
  do tag <- getWord16be
     case tag of
       1   -> return (QClass IN)
       2   -> return (QClass CS)
       3   -> return (QClass CH)
       4   -> return (QClass HS)
       255 -> return QAnyClass
       _   -> fail ("Invalid CLASS: " ++ show tag)

getName :: Get Name
getName  =
  do labels <- go
     addLabels labels
     return (labelsToName labels)
  where
  go = do off <- getOffset
          len <- getWord8
          if | len .&. 0xc0 == 0xc0 ->
               do l <- getWord8
                  let ptr = fromIntegral ((0x3f .&. len) `shiftL` 8)
                          + fromIntegral l
                  ns <- lookupPtr ptr
                  return [Ptr off ns]

             | len == 0 ->
                  return []

             | otherwise ->
               do l  <- getBytes (fromIntegral len)
                  ls <- go
                  return (Label off l:ls)

getClass :: Get Class
getClass  = label "CLASS" $
  do qc <- getQClass
     case qc of
       QClass c  -> return c
       QAnyClass -> fail "Invalid CLASS"

getRData :: Type -> Get RData
getRData ty = label (show ty) $
  do len <- getWord16be
     isolate (fromIntegral len) $ case ty of
       A     -> RDA  `fmap` liftGet 4 parseIP4
       NS    -> RDNS `fmap` getName
       MD    -> RDMD `fmap` getName
       MF    -> RDMF `fmap` getName
       CNAME -> RDCNAME `fmap` getName
       SOA   -> do mname   <- getName
                   rname   <- getName
                   serial  <- getWord32be
                   refresh <- getInt32be
                   retry   <- getInt32be
                   expire  <- getInt32be
                   minTTL  <- getWord32be
                   return (RDSOA mname rname serial refresh retry expire minTTL)
       MB    -> RDMB `fmap` getName
       MG    -> RDMG `fmap` getName
       MR    -> RDMR `fmap` getName
       NULL  -> RDNULL `fmap` (getBytes =<< lift C.remaining)
       PTR   -> RDPTR `fmap` getName

       HINFO -> do cpuLen <- getWord8
                   cpu    <- getBytes (fromIntegral cpuLen)
                   osLen  <- getWord8
                   os     <- getBytes (fromIntegral osLen)
                   return (RDHINFO cpu os)

       MINFO -> do rmailBx <- getName
                   emailBx <- getName
                   return (RDMINFO rmailBx emailBx)

       MX    -> do pref <- getWord16be
                   ex   <- getName
                   return (RDMX pref ex)

       _     -> RDUnknown ty `fmap` (getBytes =<< lift C.remaining)


-- Rendering -------------------------------------------------------------------

renderDNSPacket :: DNSPacket -> L.ByteString
renderDNSPacket pkt = chunk (runPut (putDNSPacket pkt))

putDNSPacket :: Putter DNSPacket
putDNSPacket DNSPacket{ .. } =
  do putDNSHeader dnsHeader

     putWord16be (fromIntegral (length dnsQuestions))
     putWord16be (fromIntegral (length dnsAnswers))
     putWord16be (fromIntegral (length dnsAuthorityRecords))
     putWord16be (fromIntegral (length dnsAdditionalRecords))

     traverse_ putQuery dnsQuestions
     traverse_ putRR dnsAnswers
     traverse_ putRR dnsAuthorityRecords
     traverse_ putRR dnsAdditionalRecords

putDNSHeader :: Putter DNSHeader
putDNSHeader DNSHeader { .. } =
  do putWord16be dnsId
     let flag i b w | b         = setBit w i
                    | otherwise = clearBit w i
         flags = flag 15 (not dnsQuery)
               $ flag 10 dnsAA
               $ flag  9 dnsTC
               $ flag  8 dnsRD
               $ flag  7 dnsRA
               $ flag  4 False -- dnsZ
               $ (renderOpCode dnsOpCode `shiftL` 11) .|. renderRespCode dnsRC
     putWord16be flags

renderOpCode :: OpCode -> Word16
renderOpCode OpQuery        = 0
renderOpCode OpIQuery       = 1
renderOpCode OpStatus       = 2
renderOpCode (OpReserved c) = c .&. 0xf

renderRespCode :: RespCode -> Word16
renderRespCode RespNoError        = 0
renderRespCode RespFormatError    = 1
renderRespCode RespServerFailure  = 2
renderRespCode RespNameError      = 3
renderRespCode RespNotImplemented = 4
renderRespCode RespRefused        = 5
renderRespCode (RespReserved c)   = c .&. 0xf

putName :: Putter Name
putName  = go
  where
  go (l:ls)
    | S.null l        = putWord8 0
    | S.length l > 63 = error "Label too big"
    | otherwise       = do putWord8 (fromIntegral len)
                           putByteString l
                           go ls
    where
    len = S.length l

  go [] = putWord8 0

putQuery :: Putter Query
putQuery Query { .. } =
  do putName   qName
     putQType  qType
     putQClass qClass

putType :: Putter Type
putType A     = putWord16be 1
putType NS    = putWord16be 2
putType MD    = putWord16be 3
putType MF    = putWord16be 4
putType CNAME = putWord16be 5
putType SOA   = putWord16be 6
putType MB    = putWord16be 7
putType MG    = putWord16be 8
putType MR    = putWord16be 9
putType NULL  = putWord16be 10
putType PTR   = putWord16be 12
putType HINFO = putWord16be 13
putType MINFO = putWord16be 14
putType MX    = putWord16be 15
putType AAAA  = putWord16be 28

putQType :: Putter QType
putQType (QType ty) = putType ty
putQType AFXR       = putWord16be 252
putQType MAILB      = putWord16be 253
putQType MAILA      = putWord16be 254
putQType QTAny      = putWord16be 255

putQClass :: Putter QClass
putQClass (QClass c) = putClass c
putQClass QAnyClass  = putWord16be 255

putRR :: Putter RR
putRR RR { .. } =
  do putName rrName
     let (ty,rdata) = putRData rrRData
     putType ty
     putClass rrClass
     putWord32be (fromIntegral rrTTL)
     putWord16be (fromIntegral (S.length rdata))
     putByteString rdata

putClass :: Putter Class
putClass IN = putWord16be 1
putClass CS = putWord16be 2
putClass CH = putWord16be 3
putClass HS = putWord16be 4

putRData :: RData -> (Type,S.ByteString)
putRData rd = case rd of
  RDA addr           -> rdata A     (renderIP4 addr)
  RDNS name          -> rdata NS    (putName name)
  RDMD name          -> rdata MD    (putName name)
  RDMF name          -> rdata MF    (putName name)
  RDCNAME name       -> rdata CNAME (putName name)
  RDSOA m r s f t ex ttl ->
                        rdata SOA $ do putName m
                                       putName r
                                       putWord32be s
                                       putInt32be f
                                       putInt32be t
                                       putInt32be ex
                                       putWord32be ttl

  RDMB name          -> rdata MB    (putName name)
  RDMG name          -> rdata MG    (putName name)
  RDMR name          -> rdata MR    (putName name)
  RDNULL bytes       -> rdata NULL  $ do putWord8 (fromIntegral (S.length bytes))
                                         putByteString bytes
  RDPTR name         -> rdata PTR   (putName name)
  RDHINFO cpu os     -> rdata HINFO $ do putWord8 (fromIntegral (S.length cpu))
                                         putByteString cpu
                                         putWord8 (fromIntegral (S.length os))
                                         putByteString os
  RDMINFO rm em      -> rdata MINFO $ do putName rm
                                         putName em
  RDMX pref ex       -> rdata MX    $ do putWord16be pref
                                         putName ex
  RDUnknown ty bytes -> (ty,bytes)
  where
  rdata tag m = (tag,runPut m)