{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}

module Hans.Layer.Dns (
    DnsHandle
  , runDnsLayer
  , DnsException

  , addNameServer
  , removeNameServer

  , HostName
  , HostEntry(..)
  , getHostByName
  , getHostByAddr
  ) where

import Hans.Address.IP4
import Hans.Channel
import Hans.Layer
import Hans.Layer.Udp as Udp
import Hans.Message.Dns
import Hans.Message.Udp
import Hans.Timers

import Control.Concurrent ( forkIO, MVar, newEmptyMVar, takeMVar, putMVar )
import Control.Monad ( mzero, guard, when )
import Data.Bits ( shiftR, xor, (.&.), (.|.) )
import Data.Foldable ( foldl' )
import Data.List ( intercalate )
import Data.String ( fromString )
import Data.Typeable ( Typeable )
import Data.Word ( Word16 )
import MonadLib ( get, set )
import qualified Control.Exception as X
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict as Map


-- External Interface ----------------------------------------------------------

-- | The channel used to hand 'Dns' actions to the DNS layer's thread.
type DnsHandle = Channel (Dns ())

-- | Fork the thread that runs the DNS layer.  It starts from an empty state
-- (no name servers) and processes the actions received on the given handle,
-- using the given UDP layer for network traffic.
runDnsLayer :: DnsHandle -> UdpHandle -> IO ()
runDnsLayer h udp =
  forkIO (loopLayer "dns" initialState (receive h) id) >> return ()
  where
  initialState = emptyDnsState h udp

-- | Failures reported by the DNS layer.  'getHostByName' and
-- 'getHostByAddr' rethrow them in the calling thread with 'X.throwIO'.
data DnsException = NoNameServers
                    -- ^ No name servers have been configured
                  | OutOfServers
                    -- ^ Ran out of name servers to try
                  | DoesNotExist
                    -- ^ Unable to find any information about the host.
                    -- NOTE(review): not raised anywhere in this file.
                  | DnsRequestFailed
                    -- ^ A name server answered, but with a response code
                    -- other than 'RespNoError'
                    deriving (Show,Typeable)

instance X.Exception DnsException

-- | Add a name server.  It goes to the front of the list, so lookups started
-- after this call will try it before any previously added server.
addNameServer :: DnsHandle -> IP4 -> IO ()
addNameServer h addr = send h $
  do st <- get
     let servers = addr : dnsNameServers st
     set $! st { dnsNameServers = servers }

-- | Remove every occurrence of a name server from the configured list.
-- Lookups that are already in flight keep the server list they started with.
removeNameServer :: DnsHandle -> IP4 -> IO ()
removeNameServer h addr = send h $
  do st <- get
     let remaining = [ ns | ns <- dnsNameServers st, ns /= addr ]
     set $! st { dnsNameServers = remaining }

-- | A host name in dotted form, e.g. @"www.example.com"@.
type HostName = String

-- | The result of a lookup.
data HostEntry = HostEntry { hostName      :: HostName
                             -- ^ Forward lookups: the queried name, or the
                             -- target of the last CNAME seen.  Reverse
                             -- lookups: the name from the last PTR record
                             -- (empty if there was none).
                           , hostAliases   :: [HostName]
                             -- ^ Names replaced by CNAME targets, most
                             -- recent first (forward lookups only)
                           , hostAddresses :: [IP4]
                             -- ^ Forward lookups: the A records, in reverse
                             -- answer order.  Reverse lookups: the queried
                             -- address.
                           } deriving (Show)

-- | Forward lookup: ask the configured name servers for the A records of the
-- given host.  Blocks until the DNS layer answers.  Failures are thrown as a
-- 'DnsException'.
getHostByName :: DnsHandle -> HostName -> IO HostEntry
getHostByName h host =
  do reply <- newEmptyMVar
     send h (getHostEntry reply (FromHost host))
     either X.throwIO return =<< takeMVar reply


-- | Reverse lookup: ask the configured name servers for the PTR record of
-- the given address.  Blocks until the DNS layer answers.  Failures are
-- thrown as a 'DnsException'.
getHostByAddr :: DnsHandle -> IP4 -> IO HostEntry
getHostByAddr h addr =
  do reply <- newEmptyMVar
     send h (getHostEntry reply (FromIP4 addr))
     either X.throwIO return =<< takeMVar reply


-- Handlers --------------------------------------------------------------------

-- | Actions run by the DNS layer's thread, over 'DnsState'.
type Dns = Layer DnsState

-- | State of the DNS layer.
data DnsState = DnsState { dnsSelf        :: {-# UNPACK #-} !DnsHandle
                           -- ^ This layer's own handle, used to queue
                           -- follow-up actions back into the layer
                         , dnsUdpHandle   :: {-# UNPACK #-} !UdpHandle
                           -- ^ The UDP layer used to send and receive queries
                         , dnsNameServers :: ![IP4]
                           -- ^ Configured name servers; a lookup tries them
                           -- in list order
                         , dnsReqId       :: {-# UNPACK #-} !Word16
                           -- ^ Id for the next request, advanced with
                           -- 'stepReqId'
                         , dnsQueries     :: !(Map.Map Word16 DnsQuery)
                           -- ^ Outstanding requests, keyed by request id
                         , dnsTimeout     :: {-# UNPACK #-} !Milliseconds
                           -- ^ How long to wait for a name server before
                           -- the request's timer fires
                         }

-- | The state a freshly started DNS layer begins with: no name servers, no
-- outstanding requests, and a three-minute per-server timeout.
emptyDnsState :: DnsHandle -> UdpHandle -> DnsState
emptyDnsState self udp =
  DnsState { dnsQueries     = Map.empty
           , dnsNameServers = []
             -- seed for 'stepReqId'; must be non-zero, since 0 maps to 0
           , dnsReqId       = 1
           , dnsTimeout     = 180000 -- 3 minutes
           , dnsSelf        = self
           , dnsUdpHandle   = udp
           }

-- | Advance the request-id generator one step: a 16-bit Galois LFSR with
-- feedback polynomial x^16 + x^14 + x^13 + x^11 + 1.
--
-- 0xB400 ~ bit 15 .|. bit 13 .|. bit 12 .|. bit 10 is the toggle mask.  It is
-- XOR-ed in whenever the bit shifted out is 1.  With XOR the generator is
-- maximal-length: starting from any non-zero seed it visits all 65535
-- non-zero ids before repeating.  0 is a fixed point, so it must never be
-- used as the seed.
stepReqId :: Word16 -> Word16
stepReqId w = (w `shiftR` 1) `xor` (negate (w .&. 0x1) .&. 0xB400)

-- | Store a new request under the current request id and advance the id
-- generator with 'stepReqId'.  The query is built by applying @mk@ to the
-- id, which is also returned.
--
-- No timer is set up here; 'sendQuery' arms one and attaches it through
-- 'registerTimeout'.  Because 'Map.insert' is used, an outstanding request
-- that already holds the same id would be replaced.
registerRequest :: (Word16 -> DnsQuery) -> Dns Word16
registerRequest mk =
  do st <- get
     let reqId   = dnsReqId st
         queries = Map.insert reqId (mk reqId) (dnsQueries st)
     set st { dnsQueries = queries, dnsReqId = stepReqId reqId }
     return reqId

-- | Attach the timer armed for a request to its query, so that a later
-- response can cancel it.  If the request is no longer outstanding (it was
-- already answered or dropped), the timer is cancelled instead.
registerTimeout :: Word16 -> Timer -> Dns ()
registerTimeout reqId timer =
  do queries <- fmap dnsQueries get
     maybe (output (cancel timer))
           (\query -> updateRequest reqId query { qTimeout = Just timer })
           (Map.lookup reqId queries)

-- | Store a query under the given request id, replacing any existing entry.
updateRequest :: Word16 -> DnsQuery -> Dns ()
updateRequest reqId query =
  get >>= \st ->
    set st { dnsQueries = Map.insert reqId query (dnsQueries st) }

-- | Fetch the outstanding query with the given id.  If there is none, the
-- current action is abandoned via 'mzero'.
lookupRequest :: Word16 -> Dns DnsQuery
lookupRequest reqId =
  do queries <- fmap dnsQueries get
     maybe mzero return (Map.lookup reqId queries)

-- | Forget the outstanding query with the given id (no-op if absent).
removeRequest :: Word16 -> Dns ()
removeRequest reqId =
  do st <- get
     let remaining = Map.delete reqId (dnsQueries st)
     set st { dnsQueries = remaining }


-- | What a lookup starts from.
data Source = FromHost HostName
              -- ^ Forward lookup of a host name (A query)
            | FromIP4 IP4
              -- ^ Reverse lookup of an address (PTR query under in-addr.arpa)
              deriving (Show)

-- | The question types to ask for a lookup: A for names, PTR for addresses.
sourceQType :: Source -> [QType]
sourceQType src = case src of
  FromHost{} -> [QType A]
  FromIP4{}  -> [QType PTR]

-- | The name to put in the question.  Host names are split on dots.  An
-- address a.b.c.d becomes the reverse-lookup name d.c.b.a.in-addr.arpa,
-- with each octet rendered in decimal.
sourceHost :: Source -> Name
sourceHost src = case src of
  FromHost h            -> toLabels h
  FromIP4 (IP4 a b c d) ->
    [ fromString (show octet) | octet <- [d,c,b,a] ] ++ ["in-addr","arpa"]

-- | Split a dotted host name into DNS labels.
--
-- Empty segments are dropped.  A zero-length label denotes the root and may
-- only terminate a name, so a fully-qualified name such as
-- @"example.com."@ (or a stray @".."@) must not produce an empty label in
-- the question.  @"example.com."@ and @"example.com"@ both give
-- @["example","com"]@.
toLabels :: String -> Name
toLabels = map fromString . filter (not . null) . splitDots
  where
  -- split on every '.', keeping empty segments (filtered above)
  splitDots str = case break (== '.') str of
    (as,_:bs) -> as : splitDots bs
    (as,_)    -> [as]

-- | Start a lookup; the outcome is eventually written to @res@.
--
-- With no name servers configured, 'NoNameServers' is reported immediately
-- and the action stops via 'mzero'.  Otherwise a UDP handler is registered
-- on a fresh port, and 'createRequest' is sent back to this layer with the
-- current list of name servers.
getHostEntry :: DnsResult -> Source -> Dns ()
getHostEntry res src =
  do st <- get
     let udp  = dnsUdpHandle st
         self = dnsSelf st
     case dnsNameServers st of
       [] -> do output (putError res NoNameServers)
                mzero
       servers ->
         -- register a UDP handler on a fresh port, then query the name
         -- servers in order
         output $
           do port <- addUdpHandlerAnyPort udp (serverResponse self src)
              send self (createRequest res servers port)


-- | Create the query packet, register the request with the DNS layer, and
-- send it to the first of the given name servers.
--
-- @port@ is the local UDP port whose handler will receive the responses.
createRequest :: DnsResult -> [IP4] -> Source -> UdpPort -> Dns ()
createRequest res nss src port =
  registerRequest (mkDnsQuery res nss port src) >>= sendRequest


-- | Send a request to the next name server in its queue, recording that
-- server as the one expected to answer.
--
-- When the queue is exhausted, the request is forgotten, the waiting thread
-- receives 'OutOfServers', and the UDP handler registered for the request's
-- port is removed.  Without that removal the port and handler would stay
-- allocated forever.  If the request is unknown, the action stops via
-- 'lookupRequest'.
sendRequest :: Word16 -> Dns ()
sendRequest reqId =
  do query <- lookupRequest reqId

     case qServers query of

       n:rest -> do updateRequest reqId query { qServers    = rest
                                              , qLastServer = Just n
                                              }
                    sendQuery n (qUdpPort query) reqId (qRequest query)

       -- out of servers to try
       [] -> do removeRequest reqId
                DnsState { .. } <- get
                output $ do putError (qResult query) OutOfServers
                            removeUdpHandler dnsUdpHandle (qUdpPort query)

-- | Run when the timer for the server most recently queried fires: move on
-- to the next name server in the request's queue.
--
-- 'sendRequest' reports 'OutOfServers' (and cleans up) only once every
-- server has been tried.  Previously a single timeout abandoned the lookup
-- without consulting the remaining servers.  If the request was already
-- answered and removed, 'sendRequest' finds nothing and the action stops.
expireRequest :: Word16 -> Dns ()
expireRequest reqId = sendRequest reqId

-- | Handle the response from the server.
--
-- Arguments: the 'Source' the request was made for, followed by the sender
-- address, sender port and payload of the received UDP datagram.
--
-- The datagram is dropped (via 'guard' / 'mzero') unless all of these hold:
--
--   * it comes from port 53;
--   * it parses as a DNS packet (presumably 'liftRight' fails the action on
--     a parse error; confirm in "Hans.Layer");
--   * its id belongs to an outstanding request;
--   * it was sent by the server most recently queried for that request;
--   * it is a response, not a query.
--
-- Otherwise the waiting thread receives the parsed 'HostEntry' (response
-- code 'RespNoError') or 'DnsRequestFailed' (any other code).  The request
-- is then forgotten, its UDP handler is removed, and its timer, if one was
-- registered, is cancelled.
--
-- NOTE(review): an error response code is not retried with the next server.
handleResponse :: Source -> IP4 -> UdpPort -> S.ByteString -> Dns ()
handleResponse src srcIp srcPort bytes =
  do guard (srcPort == 53)

     DNSPacket { .. } <- liftRight (parseDNSPacket bytes)
     let DNSHeader { .. } = dnsHeader
     DnsQuery { .. } <- lookupRequest dnsId

     -- require that the last name server we sent to was the one that responded,
     -- and that it responded with a response, not a request.
     guard (Just srcIp == qLastServer && not dnsQuery)

     if dnsRC == RespNoError
        then output (putResult qResult (parseHostEntry src dnsAnswers))
        else output (putError  qResult DnsRequestFailed)

     -- tear down: forget the request, release its port, and stop its timer
     removeRequest dnsId
     DnsState { .. } <- get
     output $ do removeUdpHandler dnsUdpHandle qUdpPort
                 case qTimeout of
                   Just timeout -> cancel timeout
                   Nothing      -> return ()

-- | Interpret the answer records according to the kind of lookup made.
parseHostEntry :: Source -> [RR] -> HostEntry
parseHostEntry src = case src of
  FromHost host -> parseAddr host
  FromIP4 addr  -> parsePtr addr

-- | Parse the A and CNAME parts out of a response, folding over the answers
-- from left to right.
--
-- Each A record is pushed onto the front of 'hostAddresses', so addresses
-- end up in reverse answer order.  Each CNAME target becomes the new
-- 'hostName', with the previous name pushed onto 'hostAliases'.  Other
-- record types are ignored.
parseAddr :: HostName -> [RR] -> HostEntry
parseAddr host answers = foldl' step start answers
  where
  start = HostEntry { hostName = host, hostAliases = [], hostAddresses = [] }

  step acc RR { rrRData = rdata } = case rdata of
    RDA ip         -> acc { hostAddresses = ip : hostAddresses acc }
    RDCNAME target -> acc { hostName    = joinLabels target
                          , hostAliases = hostName acc : hostAliases acc }
    _              -> acc

  joinLabels labels = intercalate "." (map C8.unpack labels)

-- | Parse the PTR part out of a reverse-lookup response.  The entry's only
-- address is the one that was queried.  Its name comes from the last PTR
-- record, or stays empty if there is none.  Other record types are ignored.
parsePtr :: IP4 -> [RR] -> HostEntry
parsePtr addr answers = foldl' step start answers
  where
  start = HostEntry { hostName = "", hostAliases = [], hostAddresses = [addr] }

  step acc RR { rrRData = RDPTR ptrName } = acc { hostName = joinLabels ptrName }
  step acc _                              = acc

  joinLabels labels = intercalate "." (map C8.unpack labels)


-- Query Management ------------------------------------------------------------

-- | The one-shot slot through which a lookup's outcome reaches the waiting
-- client thread.
type DnsResult = MVar (Either DnsException HostEntry)

-- | Deliver a successful lookup to the waiting client.
putResult :: DnsResult -> HostEntry -> IO ()
putResult var = putMVar var . Right

-- | Deliver a failed lookup to the waiting client.
putError :: DnsResult -> DnsException -> IO ()
putError var = putMVar var . Left

-- | Book-keeping for one outstanding request, keyed by request id in
-- 'dnsQueries'.
data DnsQuery = DnsQuery { qResult     :: DnsResult
                           -- ^ The handle back to the thread waiting for the
                           -- HostEntry
                         , qUdpPort    :: !UdpPort
                           -- ^ The local port this request is receiving on
                         , qRequest    :: L.ByteString
                           -- ^ The encoded packet, resent unchanged to each
                           -- server tried
                         , qServers    :: [IP4]
                           -- ^ Name servers left to try
                         , qLastServer :: Maybe IP4
                           -- ^ The last server queried; only its responses
                           -- are accepted
                         , qTimeout    :: Maybe Timer
                           -- ^ The timer for the current request
                         }

-- | Build the book-keeping record for a new request.  The encoded query
-- packet asks about the source's name with the source's question types and
-- carries the given request id.  No server has been queried yet and no
-- timer is armed.
mkDnsQuery :: DnsResult -> [IP4] -> UdpPort -> Source -> Word16 -> DnsQuery
mkDnsQuery res nss port src reqId =
  DnsQuery { qRequest    = packet
           , qResult     = res
           , qServers    = nss
           , qUdpPort    = port
           , qLastServer = Nothing
           , qTimeout    = Nothing
           }
  where
  packet = renderDNSPacket (mkDNSPacket (sourceHost src) (sourceQType src) reqId)


-- | Build a standard query packet with recursion desired (RD set).  It has
-- one IN-class question per requested type, all for the same name, and
-- empty answer, authority and additional sections.
mkDNSPacket :: Name -> [QType] -> Word16 -> DNSPacket
mkDNSPacket name qs reqId =
  DNSPacket { dnsHeader            = queryHeader
            , dnsQuestions         = map question qs
            , dnsAnswers           = []
            , dnsAuthorityRecords  = []
            , dnsAdditionalRecords = []
            }
  where
  question qty = Query { qName = name, qType = qty, qClass = QClass IN }

  queryHeader = DNSHeader { dnsId     = reqId
                          , dnsQuery  = True
                          , dnsOpCode = OpQuery
                          , dnsRD     = True
                          , dnsAA     = False
                          , dnsTC     = False
                          , dnsRA     = False
                          , dnsRC     = RespNoError
                          }


-- UDP Interaction -------------------------------------------------------------

-- | Send the encoded query to the given name server (UDP port 53) from local
-- port @sp@.  Then arm a timer that queues 'expireRequest' for the request
-- after 'dnsTimeout', and hand the timer to the layer via
-- 'registerTimeout' so that a response can cancel it.
--
-- The expiry action only queues work back into the layer.  The leftover
-- debugging @`X.finally` putStrLn "KILLED"@, which printed to stdout every
-- time a timer fired, has been removed.
sendQuery :: IP4 -> UdpPort -> Word16 -> L.ByteString -> Dns ()
sendQuery nameServer sp reqId bytes =
  do DnsState { .. } <- get
     output $ do sendUdp dnsUdpHandle nameServer (Just sp) 53 bytes
                 expire <- delay dnsTimeout (send dnsSelf (expireRequest reqId))
                 send dnsSelf (registerTimeout reqId expire)

-- | UDP handler for a request's port: forward each datagram (sender address,
-- sender port, payload) to the DNS layer as a 'handleResponse' action.
-- The 'UdpPort' argument is ignored; presumably it is the local port the
-- handler was registered on (confirm against 'addUdpHandlerAnyPort').
serverResponse :: DnsHandle -> Source -> UdpPort -> Udp.Handler
serverResponse dns src _ownPort srcIp srcPort bytes =
  send dns $ handleResponse src srcIp srcPort bytes