{-# LANGUAGE OverloadedStrings, DeriveDataTypeable #-}

module Network.DNS.Decode (
    decode
  , receive
  , receive'
  ) where

import Control.Applicative ((<$), (<$>), (<*), (<*>))
import Control.Monad (replicateM)
import Control.Monad.Trans.Resource (ResourceT, runResourceT)
import qualified Control.Exception as ControlException
import Data.Bits ((.&.), shiftR, testBit)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (($$), Source)
import Data.Conduit.Network (sourceSocket)
import Data.IP (toIPv4, toIPv6)
import Data.Traversable (traverse)
import Data.Typeable (Typeable)
import Network (Socket)
import Network.DNS.Internal
import Network.DNS.StateBinary

----------------------------------------------------------------


-- | Exception thrown by 'receive' when 'unpackBytes' reports an error
--   while converting the received message. The 'String' is the error
--   message produced by that conversion.
data RDATAParseError = RDATAParseError String
 deriving (Show, Typeable)

-- Makes 'RDATAParseError' throwable via 'ControlException.throwIO'.
instance ControlException.Exception RDATAParseError


-- | Receive DNS data from a 'Socket' and parse it completely.
--
--   The message is first read with 'receiveDNSFormat'; every RDATA
--   section is then converted with 'unpackBytes'. If that conversion
--   returns an error, an 'RDATAParseError' carrying the message is
--   thrown in 'IO'.
receive :: Socket -> IO DNSFormat
receive sock =
    receiveDNSFormat (sourceSocket sock) >>= finish . traverse unpackBytes
  where
    -- Turn a conversion failure into an exception, pass success through.
    finish = either (ControlException.throwIO . RDATAParseError) return

-- | Receive DNS data from a 'Socket' and parse it only partially:
--   RDATA sections of unknown type are left as raw 'ByteString's
--   instead of being unpacked.
receive' :: Socket -> IO (DNSMessage (RD ByteString))
receive' sock = receiveDNSFormat (sourceSocket sock)


----------------------------------------------------------------

-- | Parse a DNS message from a lazy 'BL.ByteString'.
--
--   Runs 'decodeResponse' over the input and then converts every RDATA
--   section with 'unpackBytes'. A failure in either step is returned as
--   'Left' with its error message.
decode :: BL.ByteString -> Either String DNSFormat
decode bs =
    -- The second tuple component returned by 'runSGet' is discarded.
    either Left (traverse unpackBytes . fst) (runSGet decodeResponse bs)

-- | Convert the payload of an 'RD_OTH' value from a strict 'ByteString'
--   into a list of 'Int's by running 'getNBytes' over all of its bytes.
--   Any other constructor is passed through with only its type changed.
unpackBytes :: RD ByteString -> Either String (RD [Int])
unpackBytes (RD_OTH raw) = fmap (RD_OTH . fst) (runSGet (getNBytes size) lazyRaw)
  where
    size    = fromIntegral (BS.length raw)
    lazyRaw = BL.fromChunks [raw]
-- The mapped function is 'error': 'RD_OTH' was handled above, so this
-- 'fmap' is only there to change the type parameter.
-- NOTE(review): this relies on the 'Functor' instance of 'RD' applying
-- the function to 'RD_OTH' payloads only -- confirm against the instance.
unpackBytes other = Right (fmap (error "unhandled case in decode") other)

----------------------------------------------------------------
-- | Connect a 'ByteString' source to a sink built from 'decodeResponse'
--   and run it inside 'runResourceT'. Only the first component of the
--   sink's result (the decoded message) is returned.
receiveDNSFormat :: Source (ResourceT IO) ByteString -> IO (DNSMessage (RD ByteString))
receiveDNSFormat src =
    fmap fst (runResourceT (src $$ sinkSGet decodeResponse))

----------------------------------------------------------------

-- | Decode a whole DNS message: the header first, then the question
--   section and the three resource-record sections, each with as many
--   entries as the corresponding header counter announces.
decodeResponse :: SGet (DNSMessage (RD ByteString))
decodeResponse = do
    hd          <- decodeHeader
    questions   <- decodeQueries (qdCount hd)
    answers     <- decodeRRs (anCount hd)
    authorities <- decodeRRs (nsCount hd)
    additionals <- decodeRRs (arCount hd)
    return (DNSFormat hd questions answers authorities additionals)

----------------------------------------------------------------

-- | Decode the 16-bit flags word of the header into 'DNSFlags'.
--
--   Bit 15 selects response/query, bits 11-14 hold the opcode, bits 10,
--   9, 8 and 7 are the four boolean flags passed to 'DNSFlags' (in that
--   order), and the low four bits hold the response code.
decodeFlags :: SGet DNSFlags
decodeFlags = do
    w <- get16
    return $ DNSFlags (direction w)
                      (opcode w)
                      (testBit w 10)
                      (testBit w 9)
                      (testBit w 8)
                      (testBit w 7)
                      (rcode w)
  where
    direction w
        | testBit w 15 = QR_Response
        | otherwise    = QR_Query
    -- NOTE(review): 'toEnum' may be partial here if the opcode/rcode
    -- enumerations do not cover all sixteen values -- confirm.
    opcode w = toEnum (fromIntegral ((w `shiftR` 11) .&. 0x0f))
    rcode w  = toEnum (fromIntegral (w .&. 0x0f))

----------------------------------------------------------------

-- | Decode the message header: a 16-bit identifier, the flags word
--   (see 'decodeFlags') and four 16-bit counters that are passed to
--   'DNSHeader' in the order they are read.
decodeHeader :: SGet DNSHeader
decodeHeader = do
    ident   <- getInt16
    flags   <- decodeFlags
    qdcount <- getInt16
    ancount <- getInt16
    nscount <- getInt16
    arcount <- getInt16
    return (DNSHeader ident flags qdcount ancount nscount arcount)

----------------------------------------------------------------

-- | Decode the given number of consecutive question entries.
decodeQueries :: Int -> SGet [Question]
decodeQueries = flip replicateM decodeQuery

-- | Read a 16-bit value and convert it to a 'TYPE' with 'intToType'.
decodeType :: SGet TYPE
decodeType = fmap intToType getInt16

-- | Decode one question entry: a domain name followed by a type.
--   The class field that follows the type is read and discarded.
decodeQuery :: SGet Question
decodeQuery = do
    name <- decodeDomain
    typ  <- decodeType
    ignoreClass
    return (Question name typ)

-- | Decode the given number of consecutive resource records.
decodeRRs :: Int -> SGet [RR (RD ByteString)]
decodeRRs = flip replicateM decodeRR

-- | Decode a single resource record.
--
--   The leading name/type/class triple has the same layout as a
--   question entry, so it is read with 'decodeQuery'. It is followed by
--   a 32-bit TTL, a 16-bit RDATA length and the RDATA itself, which is
--   decoded according to the record type.
decodeRR :: SGet (RR (RD ByteString))
decodeRR = do
    Question owner recType <- decodeQuery
    lifetime <- fmap fromIntegral get32
    size     <- getInt16
    payload  <- decodeRData recType size
    return ResourceRecord { rrname = owner
                          , rrtype = recType
                          , rrttl  = lifetime
                          , rdlen  = size
                          , rdata  = payload
                          }

-- | Decode the RDATA section of a resource record.
--
--   The first argument is the record type and the second is the RDATA
--   length that 'decodeRR' read from the record header.
--
--   * Name-based records (NS, MX, CNAME, DNAME, SOA, PTR, SRV) are
--     decoded field by field and do not consult the length.
--
--   * TXT reads @len@ bytes and drops the first one (the length prefix
--     of the character-string).
--
--   * A and AAAA require a length of exactly 4 and 16 bytes
--     respectively; any other length makes the parser fail instead of
--     handing a list of the wrong size to 'toIPv4' / 'toIPv6'.
--
--   * Every other type is kept as the raw @len@ bytes in 'RD_OTH'.
decodeRData :: TYPE -> Int -> SGet (RD ByteString)
decodeRData NS _ = RD_NS <$> decodeDomain
decodeRData MX _ = RD_MX <$> decodePreference <*> decodeDomain
  where
    decodePreference = getInt16
decodeRData CNAME _ = RD_CNAME <$> decodeDomain
decodeRData DNAME _ = RD_DNAME <$> decodeDomain
decodeRData TXT len = (RD_TXT . ignoreLength) <$> getNByteString len
  where
    -- 'BS.drop' is total, whereas 'BS.tail' throws on an empty
    -- 'ByteString' (which a record with a zero RDATA length yields).
    ignoreLength = BS.drop 1
decodeRData A len
    | len == 4  = (RD_A . toIPv4) <$> getNBytes len
    | otherwise = fail $ "decodeRData: A record with RDLENGTH " ++ show len
decodeRData AAAA len
    | len == 16 = (RD_AAAA . toIPv6 . combine) <$> getNBytes len
    | otherwise = fail $ "decodeRData: AAAA record with RDLENGTH " ++ show len
  where
    -- Merge consecutive byte pairs into 16-bit groups (big-endian).
    -- The length check above guarantees an even number of bytes, so
    -- the catch-all only ever sees the empty list.
    combine (a:b:cs) = a * 256 + b : combine cs
    combine _        = []
decodeRData SOA _ = RD_SOA <$> decodeDomain
                           <*> decodeDomain
                           <*> decodeSerial
                           <*> decodeRefesh
                           <*> decodeRetry
                           <*> decodeExpire
                           <*> decodeMinumun
  where
    decodeSerial  = getInt32
    decodeRefesh  = getInt32
    decodeRetry   = getInt32
    decodeExpire  = getInt32
    decodeMinumun = getInt32
decodeRData PTR _ = RD_PTR <$> decodeDomain
decodeRData SRV _ = RD_SRV <$> decodePriority
                           <*> decodeWeight
                           <*> decodePort
                           <*> decodeDomain
  where
    decodePriority = getInt16
    decodeWeight   = getInt16
    decodePort     = getInt16

decodeRData _  len = RD_OTH <$> getNByteString len

----------------------------------------------------------------

-- | Decode a (possibly compressed) domain name.
--
--   The first byte decides what follows:
--
--   * @0@ ends the name and yields the empty domain;
--
--   * a byte with both top bits set starts a two-byte compression
--     pointer whose offset is looked up with 'pop';
--
--   * anything else is treated as a label length (low six bits),
--     followed by that many bytes and the rest of the name.
--
--   Every name decoded from a label or a pointer is registered with
--   'push' under the position at which it started.
decodeDomain :: SGet Domain
decodeDomain = do
    start <- getPosition
    lead  <- getInt8
    dispatch start lead
  where
    lowSix c    = c .&. 0x3f
    isPointer c = testBit c 7 && testBit c 6

    dispatch start lead
        | lead == 0      = return ""
        | isPointer lead = followPointer start (lowSix lead)
        | otherwise      = readLabel start (lowSix lead)

    -- The low six bits of the first byte and the whole second byte
    -- together form the offset of the referenced name.
    followPointer start high = do
        low <- getInt8
        let offset = high * 256 + low
        found <- pop offset
        case found of
            Nothing  -> fail $ "decodeDomain: " ++ show offset
            -- A pointer may refer to another pointer, so this position
            -- is registered for the resolved domain as well.
            Just dom -> push start dom >> return dom

    -- One label, a dot, then whatever the remainder of the name is.
    readLabel start size = do
        label <- getNByteString size
        rest  <- decodeDomain
        let dom = label `BS.append` "." `BS.append` rest
        push start dom
        return dom

-- | Read the 16-bit class field and throw it away.
ignoreClass :: SGet ()
ignoreClass = get16 >> return ()