-- |
-- Module:     Network.DnsCache
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module implements an asynchronous, caching DNS resolver.

{-# LANGUAGE GADTs, TypeFamilies #-}

module Network.DnsCache
    ( -- * DNS cache
      DnsCache,
      withDnsCache,

      -- * DNS lookup
      resolveA,
      resolveAAAA,
      resolveMX,

      -- * DNS mass lookup
      MassResult(..),
      MassType(..),
      massLookup,
      massLookupReport
    )
    where

import qualified Control.Exception as E
import qualified Data.Map as M
import Control.Concurrent
import Control.ContStuff
import Data.Char
import Data.IP
import Data.List
import Data.Map (Map)
import Data.Ord
import Data.Time.Clock
import Network.DNS hiding (lookup)


-- | Command to a DNS cache thread.  The type parameter @res@ is the record
-- type stored by that cache.

data CacheCmd res
    = CacheGet Domain (Waiter res)
      -- ^ Answer the waiter from the cache entry for the domain.  If a
      -- lookup for the domain is still in progress, the waiter is queued
      -- and answered when that lookup finishes.
    | CacheQuit (IO ())
      -- ^ Run the action, then stop the cache thread.
    | CacheRequest Domain (Bool -> IO ())
      -- ^ Ask whether the caller should look the domain up.  The callback
      -- receives 'True' (and the entry is marked as resolving) unless a
      -- lookup is already in progress or an unexpired result is cached.
    | CacheSetNotFound Domain
      -- ^ Record that the lookup yielded no result; queued waiters receive
      -- 'Nothing'.
    | CacheSetResolved Domain [res]
      -- ^ Store a lookup result; queued waiters receive it.


-- | Handle to a running DNS cache, as passed to the computation given to
-- 'withDnsCache'.  It holds the command variables of the helper threads.

data DnsCache =
    DnsCache {
      -- | Commands for the A record cache thread.
      dnsAVar    :: MVar (CacheCmd IPv4),
      -- | Commands for the AAAA record cache thread.
      dnsAaaaVar :: MVar (CacheCmd IPv6),
      -- | Commands for the MX record cache thread.
      dnsMxVar   :: MVar (CacheCmd Domain),
      -- | Commands shared by all resolver threads of the pool.
      dnsResVar  :: MVar ResolverCmd
    }


-- | DNS cache entry for a single domain.

data Entry res
    = NotFound
      -- ^ The last lookup yielded no result.  Such an entry is not treated
      -- as fresh: the next 'CacheRequest' for the domain grants a new
      -- lookup.
    | Resolving [Waiter res]
      -- ^ A lookup is in progress; the queued waiters are answered when it
      -- finishes.
    | Resolved UTCTime [res]
      -- ^ Lookup result together with the time it was stored, which is
      -- compared against the cache timeout to decide whether it expired.


-- | Result of a mass lookup for one domain.  Record types that were not
-- requested, or whose lookup yielded no result, are empty lists.

data MassResult =
    MassResult {
      -- | The domain exactly as it was passed to the mass lookup.
      massDomain :: Domain,
      -- | A records.
      massA      :: [IPv4],
      -- | AAAA records.
      massAAAA   :: [IPv6],
      -- | MX hosts, ordered by ascending preference value.
      massMX     :: [Domain]
    }
    deriving Show


-- | The type of resources to look up in a mass lookup: only A records, only
-- AAAA records, only MX records, or all three.

data MassType = MassA | MassAAAA | MassMX | MassAll


-- | Command to a resolver thread.

data ResolverCmd where
    -- | Resolve a domain.  The cache thread behind the variable decides
    -- whether this resolver performs the lookup with the given function or
    -- whether the answer comes from the cache.  The waiter receives the
    -- answer.  The record type @res@ does not appear in the result type, so
    -- commands for every record type share one resolver pool.
    Resolve ::
        Domain ->
        MVar (CacheCmd res) ->
        (Resolver -> Domain -> IO (Maybe [res])) ->
        Waiter res ->
        ResolverCmd
    -- | Run the action, then stop the resolver thread.
    ResolverQuit :: IO () -> ResolverCmd


-- | Callback that receives the answer to a DNS request: the records, or
-- 'Nothing' if the lookup yielded no result.

type Waiter res = Maybe [res] -> IO ()


-- | DNS cache thread.  Keeps a map from domain to 'Entry' and serves the
-- 'CacheCmd' requests read from the command variable one at a time, until
-- it receives 'CacheQuit'.
--
-- Resolved entries younger than the timeout are served from the cache;
-- older ones, 'NotFound' entries and unknown domains are handed out for
-- (re-)resolution by 'CacheRequest'.  Waiters are run on this thread, so a
-- waiter that blocks stalls the whole cache.  Entries are only inserted or
-- overwritten, never removed.

dnsCache :: NominalDiffTime -> MVar (CacheCmd res) -> IO ()
dnsCache timeout cmdVar =
    evalStateT M.empty . forever $ do
        cmd <- io $ takeMVar cmdVar
        case cmd of
          -- Answer from the cache.  For a domain that is still being
          -- resolved the waiter is queued; it is answered later by
          -- CacheSetNotFound or CacheSetResolved.
          CacheGet dom waiter -> do
              mCache <- getField (M.lookup dom)
              case mCache of
                Nothing       -> io $ waiter Nothing
                Just NotFound -> io $ waiter Nothing
                Just (Resolved _ res) -> io $ waiter (Just res)
                Just (Resolving ws) ->
                    modify (M.insert dom (Resolving (waiter:ws)))

          -- Run the quit callback, then leave the loop via 'abort'.
          CacheQuit c -> io c >> abort ()

          -- Grant the lookup (c True) unless one is already running or the
          -- cached result is younger than the timeout.  Granting marks the
          -- entry as resolving, so concurrent requests for the same domain
          -- are refused (c False) until the lookup finishes.
          CacheRequest dom c -> do
              mCache <- getField (M.lookup dom)
              curTime <- io getCurrentTime
              case mCache of
                Just (Resolving _) -> io $ c False
                Just (Resolved time _) | diffUTCTime curTime time < timeout ->
                    io $ c False
                _ -> do
                    modify (M.insert dom (Resolving []))
                    io $ c True

          -- The lookup failed: answer the queued waiters with Nothing, then
          -- remember the failure.
          CacheSetNotFound dom -> do
              mCache <- getField (M.lookup dom)
              case mCache of
                Just (Resolving cs) -> io $ mapM_ ($ Nothing) cs
                _                   -> return ()
              modify (M.insert dom NotFound)

          -- The lookup succeeded: answer the queued waiters, then store the
          -- result with the current time for later expiry checks.
          CacheSetResolved dom res -> do
              mCache <- getField (M.lookup dom)
              case mCache of
                Just (Resolving cs) -> io $ mapM_ ($ Just res) cs
                _                   -> return ()
              curTime <- io getCurrentTime
              modify (M.insert dom (Resolved curTime res))


-- | Perform a mass lookup of the given record types for all domains and
-- collect the results in a map keyed by the domain as it was passed in.
-- Blocks until a result for every input domain has arrived.  For duplicate
-- domains, the last result to arrive is kept.

massLookup :: DnsCache -> MassType -> [Domain] -> IO (Map Domain MassResult)
massLookup dns mtype domains = do
    resultVar <- newEmptyMVar
    massResolver dns mtype domains resultVar
    -- One result per input domain is delivered, in completion order.
    let collect :: Int -> Map Domain MassResult -> IO (Map Domain MassResult)
        collect 0 acc = return acc
        collect k acc = do
            r@(MassResult dom _ _ _) <- takeMVar resultVar
            collect (k - 1) (M.insert dom r acc)
    collect (length domains) M.empty


-- | Perform a mass lookup and pass every 'MassResult' to the report action
-- as soon as it arrives, which is completion order rather than input order.
-- Returns after one report per input domain.

massLookupReport ::
    (Base m ~ IO, LiftBase m, Monad m) =>
    DnsCache -> MassType -> [Domain] -> (MassResult -> m ()) -> m ()
massLookupReport dns mtype domains rep = do
    results <- io newEmptyMVar
    io (massResolver dns mtype domains results)
    -- Each input domain yields exactly one result to report.
    mapM_ (\_ -> io (takeMVar results) >>= rep) domains


-- | Mass resolver.  Forks one thread per domain.  Each thread looks up the
-- record types selected by the 'MassType' and, once those lookups are done,
-- puts a single 'MassResult' for its domain into the result variable.
-- Record types that were not requested, and lookups that yielded no
-- result, are reported as empty lists.  Returns without waiting for the
-- lookups; callers take one result per input domain.

massResolver :: DnsCache -> MassType -> [Domain] -> MVar MassResult -> IO ()
massResolver dns mtype domains resultVar =
    forM_ domains $ \domain ->
        forkIO $ resolveDomain domain >>= putMVar resultVar
  where
    -- Look up the requested record types of one domain.  A single record
    -- type is looked up directly on the per-domain thread; only 'MassAll'
    -- forks, so that its three lookups run concurrently.
    resolveDomain :: Domain -> IO MassResult
    resolveDomain domain =
        case mtype of
          MassA -> do
              a <- fetch (resolveA dns domain)
              return (MassResult domain a [] [])
          MassAAAA -> do
              aaaa <- fetch (resolveAAAA dns domain)
              return (MassResult domain [] aaaa [])
          MassMX -> do
              mx <- fetch (resolveMX dns domain)
              return (MassResult domain [] [] mx)
          MassAll -> do
              waitA    <- spawn (resolveA dns domain)
              waitAAAA <- spawn (resolveAAAA dns domain)
              waitMX   <- spawn (resolveMX dns domain)
              a    <- waitA
              aaaa <- waitAAAA
              mx   <- waitMX
              return (MassResult domain a aaaa mx)

    -- Run a lookup on the current thread, turning 'Nothing' into [].
    fetch :: IO (Maybe [a]) -> IO [a]
    fetch = fmap (maybe [] id)

    -- Start a lookup on its own thread; the returned action waits for its
    -- result ('Nothing' again becomes []).
    spawn :: IO (Maybe [a]) -> IO (IO [a])
    spawn lookupAction = do
        var <- newEmptyMVar
        _ <- forkIO $ fetch lookupAction >>= putMVar var
        return (takeMVar var)


-- | Resolve the A records of a domain through the cache and the resolver
-- pool.  Blocks until the answer arrives; 'Nothing' means the lookup
-- yielded no result.

resolveA :: DnsCache -> Domain -> IO (Maybe [IPv4])
resolveA cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg) (Resolve dom (dnsAVar cfg) lookupA (putMVar reply))
    takeMVar reply


-- | Resolve the AAAA records of a domain through the cache and the resolver
-- pool.  Blocks until the answer arrives; 'Nothing' means the lookup
-- yielded no result.

resolveAAAA :: DnsCache -> Domain -> IO (Maybe [IPv6])
resolveAAAA cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg)
            (Resolve dom (dnsAaaaVar cfg) lookupAAAA (putMVar reply))
    takeMVar reply


-- | Resolve the MX records of a domain through the cache and the resolver
-- pool.  The hosts are returned ordered by ascending preference value.
-- Blocks until the answer arrives; 'Nothing' means the lookup yielded no
-- result.

resolveMX :: DnsCache -> Domain -> IO (Maybe [Domain])
resolveMX cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg)
            (Resolve dom (dnsMxVar cfg) byPreference (putMVar reply))
    takeMVar reply
  where
    -- Drop the preference values after sorting the hosts by them.
    byPreference rv name = do
        mxs <- lookupMX rv name
        return (fmap (map fst . sortBy (comparing snd)) mxs)


-- | DNS resolver thread.  Takes commands from the given variable until it
-- receives 'ResolverQuit'.
--
-- For a 'Resolve' command the lower-cased domain is first offered to the
-- responsible cache thread.  If the cache grants the lookup, this thread
-- performs it, stores the outcome in the cache and answers the waiter.
-- Otherwise the waiter is handed to the cache.  The cache answers it from
-- the cached entry, or when the lookup already in progress finishes.  The
-- resolver thread does not block on that and can serve other requests
-- meanwhile.

resolver :: MVar ResolverCmd -> IO ()
resolver cmdVar = do
    seed <- makeResolvSeed defaultResolvConf
    withResolver seed $ \rv ->
        evalContT . forever $ do
            cmd <- io $ takeMVar cmdVar
            case cmd of
              Resolve dom' cacheVar dnsLookup waiter ->
                  io $ do
                      let dom = map toLower dom'
                          -- Publish a lookup outcome to the cache (which
                          -- also answers any waiters queued there) and to
                          -- the requesting waiter.
                          finish Nothing =
                              putMVar cacheVar (CacheSetNotFound dom) >>
                              waiter Nothing
                          finish (Just res) =
                              putMVar cacheVar (CacheSetResolved dom res) >>
                              waiter (Just res)

                      mayResVar <- newEmptyMVar
                      putMVar cacheVar (CacheRequest dom (putMVar mayResVar))
                      mayResolve <- takeMVar mayResVar

                      if mayResolve
                        then do
                            -- I/O errors count as "no result", so this
                            -- thread survives them.  Any other exception
                            -- still propagates, but only after the cache
                            -- entry has left the Resolving state and the
                            -- waiter has been answered.  Otherwise both
                            -- would wait forever.
                            mRes <- (dnsLookup rv dom `E.catch` ioFailure)
                                        `E.onException` finish Nothing
                            finish mRes
                        else
                            -- Another lookup is running or a fresh result is
                            -- cached; the cache answers the waiter directly.
                            putMVar cacheVar (CacheGet dom waiter)

              ResolverQuit c -> io c >> abort ()
  where
    ioFailure :: E.IOException -> IO (Maybe a)
    ioFailure _ = return Nothing


-- | Start a DNS cache with the given number of resolver threads and the
-- given cache timeout, and run the computation with it.  When the
-- computation returns or throws, all resolver and cache threads are
-- stopped before 'withDnsCache' returns or rethrows.
--
-- A thread count below 1 is raised to 1: without a resolver thread every
-- lookup would block forever.

withDnsCache :: Int -> NominalDiffTime -> (DnsCache -> IO a) -> IO a
withDnsCache numThreads timeout comp = do
    let threads = max 1 numThreads

    resolverVar <- newEmptyMVar
    aVar <- newEmptyMVar
    aaaaVar <- newEmptyMVar
    mxVar <- newEmptyMVar

    -- One pool of resolver threads serves all three record caches.
    replicateM_ threads . forkIO $ resolver resolverVar
    forkIO $ dnsCache timeout aVar
    forkIO $ dnsCache timeout aaaaVar
    forkIO $ dnsCache timeout mxVar

    let cfg = DnsCache {
                dnsAVar = aVar,
                dnsAaaaVar = aaaaVar,
                dnsMxVar = mxVar,
                dnsResVar = resolverVar
              }

    E.finally (comp cfg) $ do
        quitVar <- newEmptyMVar
        -- Stop the resolvers first, one at a time.  A resolver only picks up
        -- its quit command after finishing its current request, which may
        -- still send results to a cache thread.  So the caches must outlive
        -- every resolver.
        replicateM_ threads $ do
            putMVar resolverVar (ResolverQuit $ putMVar quitVar ())
            takeMVar quitVar

        putMVar aVar (CacheQuit (putMVar quitVar ()))
        putMVar aaaaVar (CacheQuit (putMVar quitVar ()))
        putMVar mxVar (CacheQuit (putMVar quitVar ()))
        replicateM_ 3 $ takeMVar quitVar