-- |
-- Module:     Network.DnsCache
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module implements an asynchronous, caching DNS resolver.

{-# LANGUAGE FlexibleInstances, GADTs, RankNTypes, TypeFamilies #-}

module Network.DnsCache
    ( -- * DNS cache
      DnsCache,
      createDnsCache,
      freeDnsCache,
      withDnsCache,

      -- * Monadic interface
      DnsMonad(..),
      withDnsStateT,

      -- * DNS lookup
      resolveA,
      resolveAAAA,
      resolveMX,

      -- * DNS mass lookup
      MassResult(..),
      MassType(..),
      massLookup,
      massLookupReport,

      -- * Reexports
      Domain
    )
    where

import qualified Control.Exception as E
import qualified Data.Map as M
import Control.Concurrent
import Control.ContStuff
import Data.Char
import Data.IORef
import Data.IP
import Data.List
import Data.Map (Map)
import Data.Ord
import Data.Time.Clock
import Network.DNS hiding (lookup)
import System.IO.Unsafe


-- | Command understood by a DNS cache thread.  The type parameter is the
-- record type (IPv4, IPv6 or MX domain) held by that cache.

data CacheCmd res where
    -- | Answer from the cache: the waiter gets the stored records, or
    -- 'Nothing' for a missing or failed entry.  If the domain is still
    -- being resolved, the waiter is queued until the lookup finishes.
    CacheGet :: Domain -> Waiter res -> CacheCmd res
    -- | Run the given action, then stop the cache thread.
    CacheQuit :: IO () -> CacheCmd res
    -- | Ask whether the sender should resolve the domain itself.  The
    -- callback gets 'True' (and the entry is marked as resolving) when the
    -- entry is missing, failed or expired; otherwise 'False'.
    CacheRequest :: Domain -> (Bool -> IO ()) -> CacheCmd res
    -- | Record a failed lookup and wake all queued waiters with 'Nothing'.
    CacheSetNotFound :: Domain -> CacheCmd res
    -- | Record a successful lookup (time-stamped) and wake all queued
    -- waiters with the records.
    CacheSetResolved :: Domain -> [res] -> CacheCmd res


-- | Handle to a running DNS cache: the command variables of its cache
-- threads and of its resolver threads.

data DnsCache
    = DnsCache
        { dnsAVar       :: MVar (CacheCmd IPv4)    -- ^ A-record cache thread
        , dnsAaaaVar    :: MVar (CacheCmd IPv6)    -- ^ AAAA-record cache thread
        , dnsMxVar      :: MVar (CacheCmd Domain)  -- ^ MX-record cache thread
        , dnsResVar     :: MVar ResolverCmd        -- ^ shared by all resolver threads
        , dnsNumThreads :: Int                     -- ^ number of resolver threads
        }


-- | State of one domain inside a DNS cache thread.

data Entry res where
    -- | The last lookup for the domain failed.
    NotFound  :: Entry res
    -- | A lookup is in flight; these waiters are notified when it ends.
    Resolving :: [Waiter res] -> Entry res
    -- | Records from a successful lookup, with the time they were stored.
    Resolved  :: UTCTime -> [res] -> Entry res


-- | Result of a mass lookup for a single domain.  Record types that were
-- not requested, or whose lookup failed, are empty lists.

data MassResult =
    MassResult
        { massDomain :: Domain    -- ^ the domain as it was passed in
        , massA      :: [IPv4]    -- ^ A records
        , massAAAA   :: [IPv6]    -- ^ AAAA records
        , massMX     :: [Domain]  -- ^ mail exchangers, ordered by preference
        }
    deriving (Show)


-- | Which records a mass lookup fetches for each domain.
--
-- Derives the standard instances so callers can compare, order and print
-- the selector (e.g. for configuration or logging).

data MassType
    = MassA     -- ^ A records only
    | MassAAAA  -- ^ AAAA records only
    | MassMX    -- ^ MX records only
    | MassAll   -- ^ A, AAAA and MX records
    deriving (Eq, Ord, Show)


-- | Command to the resolver threads.

data ResolverCmd where
    -- | Resolve a domain.  The record type @res@ is hidden inside the
    -- command; the cache variable, the lookup function and the waiter
    -- must all agree on it.
    Resolve
        :: Domain                                   -- domain to resolve
        -> MVar (CacheCmd res)                      -- cache thread for this record type
        -> (Resolver -> Domain -> IO (Maybe [res])) -- the actual DNS lookup
        -> Waiter res                               -- receives the answer
        -> ResolverCmd
    -- | Run the given action, then stop the resolver thread that took it.
    ResolverQuit :: IO () -> ResolverCmd


-- | Callback receiving the outcome of a lookup: @'Just' records@ on
-- success, 'Nothing' if the domain could not be resolved.

type Waiter a = Maybe [a] -> IO ()


-- | Monads that carry a DNS cache, so the lookup functions of this module
-- can find one without it being passed explicitly.

class (MonadIO m) => DnsMonad m where
    -- | The DNS cache that lookups in this monad should use.
    getDnsCache :: m DnsCache

-- A 'StateT' computation whose state is a 'DnsCache' uses that state as
-- its cache (this is what 'withDnsStateT' sets up).
instance MonadIO m => DnsMonad (StateT r DnsCache m) where
    getDnsCache = get

-- Plain 'IO' computations all share one process-wide cache, 'globalCache'.
instance DnsMonad IO where
    getDnsCache = liftIO (readIORef globalCache)

-- Process-wide cache backing the 'IO' instance of 'DnsMonad'.  It is
-- created on first use with one resolver thread and a cache timeout of
-- 10 seconds; nothing in this module frees it.  The NOINLINE pragma keeps
-- GHC from duplicating the 'unsafePerformIO' call, which could otherwise
-- start several independent caches.

{-# NOINLINE globalCache #-}
globalCache :: IORef DnsCache
globalCache = unsafePerformIO $ do
    cache <- createDnsCache 1 10
    newIORef cache


-- | DNS cache thread.
--
-- Owns a map from domain to 'Entry' (kept in the 'StateT' state, so only
-- this thread ever touches it) and serves 'CacheCmd's from the given
-- variable one at a time until it receives 'CacheQuit'.
--
-- A 'Resolved' entry younger than the timeout is served from the map.  An
-- older one, a 'NotFound' entry or a missing entry is handed back to the
-- requester for (re-)resolution, so failed lookups are not cached.  Stale
-- entries are overwritten but never removed from the map.

dnsCache :: NominalDiffTime -> MVar (CacheCmd res) -> IO ()
dnsCache timeout cmdVar =
    evalStateT M.empty . forever $ do
        cmd <- liftIO $ takeMVar cmdVar
        case cmd of
          CacheGet dom waiter -> do
              mCache <- getField (M.lookup dom)
              case mCache of
                -- NOTE(review): 'resolver' only sends CacheGet after a
                -- CacheRequest answered False, so Nothing looks unreachable
                -- from there; handled defensively anyway.
                Nothing       -> liftIO $ waiter Nothing
                Just NotFound -> liftIO $ waiter Nothing
                Just (Resolved _ res) -> liftIO $ waiter (Just res)
                -- Lookup still in flight: queue the waiter (newest first).
                Just (Resolving ws) ->
                    modify (M.insert dom (Resolving (waiter:ws)))

          -- Acknowledge, then leave the loop.
          CacheQuit c -> liftIO c >> abort ()

          CacheRequest dom c -> do
              mCache <- getField (M.lookup dom)
              curTime <- liftIO getCurrentTime
              case mCache of
                -- Someone else is resolving it, or a fresh answer exists:
                -- the requester should read the cache instead.
                Just (Resolving _) -> liftIO $ c False
                Just (Resolved time _) | diffUTCTime curTime time < timeout ->
                    liftIO $ c False
                -- Missing, failed or expired: claim it for the requester.
                _ -> do
                    modify (M.insert dom (Resolving []))
                    liftIO $ c True

          CacheSetNotFound dom -> do
              mCache <- getField (M.lookup dom)
              -- Wake everybody who queued up while the lookup was running.
              case mCache of
                Just (Resolving cs) -> liftIO $ mapM_ ($ Nothing) cs
                _                   -> return ()
              modify (M.insert dom NotFound)

          CacheSetResolved dom res -> do
              mCache <- getField (M.lookup dom)
              -- Wake everybody who queued up while the lookup was running.
              case mCache of
                Just (Resolving cs) -> liftIO $ mapM_ ($ Just res) cs
                _                   -> return ()
              -- The timestamp starts the entry's freshness window.
              curTime <- liftIO getCurrentTime
              modify (M.insert dom (Resolved curTime res))


-- | Look up the A (IPv4) records of a domain through the DNS cache of the
-- current monad.  Yields 'Nothing' if the domain could not be resolved.

resolveA :: DnsMonad m => Domain -> m (Maybe [IPv4])
resolveA dom = resolveA_ dom =<< getDnsCache


-- | Look up the A (IPv4) records of a domain in the given DNS cache.
--
-- Hands a 'Resolve' request to the resolver threads and blocks until one
-- of them delivers the answer.

resolveA_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [IPv4])
resolveA_ dom cfg =
    liftIO $ do
        answer <- newEmptyMVar
        let request = Resolve dom (dnsAVar cfg) lookupA (putMVar answer)
        putMVar (dnsResVar cfg) request
        takeMVar answer


-- | Look up the AAAA (IPv6) records of a domain through the DNS cache of
-- the current monad.  Yields 'Nothing' if the domain could not be resolved.

resolveAAAA :: DnsMonad m => Domain -> m (Maybe [IPv6])
resolveAAAA dom = resolveAAAA_ dom =<< getDnsCache


-- | Look up the AAAA (IPv6) records of a domain in the given DNS cache.
--
-- Hands a 'Resolve' request to the resolver threads and blocks until one
-- of them delivers the answer.

resolveAAAA_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [IPv6])
resolveAAAA_ dom cfg =
    liftIO $ do
        answer <- newEmptyMVar
        let request = Resolve dom (dnsAaaaVar cfg) lookupAAAA (putMVar answer)
        putMVar (dnsResVar cfg) request
        takeMVar answer


-- | Look up the mail exchangers (MX records) of a domain through the DNS
-- cache of the current monad.  Yields 'Nothing' if the domain could not be
-- resolved.

resolveMX :: DnsMonad m => Domain -> m (Maybe [Domain])
resolveMX dom = resolveMX_ dom =<< getDnsCache


-- | Look up the mail exchangers (MX records) of a domain in the given DNS
-- cache.
--
-- The exchangers are ordered by the number 'lookupMX' pairs with each of
-- them (presumably the MX preference), lowest first; the numbers
-- themselves are dropped.  Blocks until a resolver thread answers.

resolveMX_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [Domain])
resolveMX_ dom cfg =
    liftIO $ do
        answer <- newEmptyMVar
        putMVar (dnsResVar cfg) $
            Resolve dom (dnsMxVar cfg) lookupSorted (putMVar answer)
        takeMVar answer

    where
        -- Stable sort on the second pair component, then keep the names.
        lookupSorted rslv d =
            fmap (fmap (map fst . sortBy (comparing snd))) (lookupMX rslv d)


-- | DNS resolver thread.
--
-- Serves 'ResolverCmd's from the given variable until it receives
-- 'ResolverQuit', whose action it runs before stopping.  For a 'Resolve'
-- command the domain is lower-cased and the cache thread of the record
-- type is asked (via 'CacheRequest') whether this thread should perform
-- the lookup itself:
--
-- * If so, the lookup runs and its outcome is stored in the cache
--   ('CacheSetNotFound' or 'CacheSetResolved') before the waiter is
--   called with it.
--
-- * Otherwise the answer is read from the cache via 'CacheGet', which
--   answers at once or, if another thread is still resolving the domain,
--   once that lookup finishes.
--
-- An exception thrown by the lookup itself is treated like a failed
-- lookup.  Without that, the exception would kill this thread and leave
-- the cache entry marked as resolving forever.  Every later request for
-- the domain would then block, the pool would lose a resolver, and
-- 'freeDnsCache' would never complete.

resolver :: MVar ResolverCmd -> IO ()
resolver cmdVar = do
    seed <- makeResolvSeed defaultResolvConf
    withResolver seed $ \rslv ->
        evalContT . forever $ do
            cmd <- liftIO $ takeMVar cmdVar
            case cmd of
              Resolve dom' cacheVar dnsLookup waiter ->
                  liftIO $ do
                      let dom = map toLower dom'
                      mayResVar <- newEmptyMVar
                      putMVar cacheVar (CacheRequest dom (putMVar mayResVar))
                      mayResolve <- takeMVar mayResVar

                      if mayResolve
                        then do
                            mRes <- lookupOrNothing (dnsLookup rslv dom)
                            case mRes of
                              Nothing -> do
                                  putMVar cacheVar (CacheSetNotFound dom)
                                  waiter Nothing
                              Just res -> do
                                  putMVar cacheVar (CacheSetResolved dom res)
                                  waiter (Just res)
                        else do
                            resVar <- newEmptyMVar
                            putMVar cacheVar (CacheGet dom (putMVar resVar))
                            takeMVar resVar >>= waiter

              ResolverQuit c -> liftIO c >> abort ()

  where
    -- Run a lookup, turning any exception it throws into a failed lookup
    -- so the cache entry always leaves the "resolving" state.
    lookupOrNothing :: IO (Maybe a) -> IO (Maybe a)
    lookupOrNothing act = act `E.catch` handler
      where
        handler :: E.SomeException -> IO (Maybe a)
        handler _ = return Nothing


-- ================== --
-- DNS cache creation --
-- ================== --


-- | Start a DNS cache with the given number of resolver threads and the
-- given cache timeout.
--
-- Forks the resolver threads plus one cache thread each for A, AAAA and
-- MX records.  A successful lookup is answered from the cache until it is
-- older than the timeout.
--
-- At least one resolver thread is always started.  With a count of zero or
-- less, no thread would ever serve lookup requests and every lookup on the
-- cache would block forever.  'dnsNumThreads' records the number of
-- threads actually started, which 'freeDnsCache' relies on to stop them.

createDnsCache :: MonadIO m => Int -> NominalDiffTime -> m DnsCache
createDnsCache requestedThreads timeout =
    liftIO $ do
        let numThreads = max 1 requestedThreads

        resolverVar <- newEmptyMVar
        aVar <- newEmptyMVar
        aaaaVar <- newEmptyMVar
        mxVar <- newEmptyMVar

        replicateM_ numThreads . forkIO $ resolver resolverVar
        _ <- forkIO $ dnsCache timeout aVar
        _ <- forkIO $ dnsCache timeout aaaaVar
        _ <- forkIO $ dnsCache timeout mxVar

        return DnsCache {
                 dnsAVar = aVar,
                 dnsAaaaVar = aaaaVar,
                 dnsMxVar = mxVar,
                 dnsResVar = resolverVar,
                 dnsNumThreads = numThreads
               }


-- | Shut down an existing DNS cache.
--
-- Stops the resolver threads one after another, waiting for each to
-- acknowledge, and then stops the three cache threads.  Afterwards no
-- thread serves the cache's command variables, so the cache must not be
-- used again.

freeDnsCache :: MonadIO m => DnsCache -> m ()
freeDnsCache dns =
    liftIO $ do
        done <- newEmptyMVar
        let signalDone = putMVar done ()
            awaitDone  = takeMVar done

        replicateM_ (dnsNumThreads dns) $
            putMVar (dnsResVar dns) (ResolverQuit signalDone) >> awaitDone

        putMVar (dnsAVar dns) (CacheQuit signalDone)
        putMVar (dnsAaaaVar dns) (CacheQuit signalDone)
        putMVar (dnsMxVar dns) (CacheQuit signalDone)
        replicateM_ 3 awaitDone


-- | Run a computation with a temporary DNS cache.
--
-- The cache is created by 'createDnsCache' with the given thread count and
-- timeout, handed to the computation, and released by 'freeDnsCache'
-- through contstuff's 'bracket'.

withDnsCache :: (HasExceptions m, MonadIO m) => Int -> NominalDiffTime -> (DnsCache -> m a) -> m a
withDnsCache numThreads timeout =
    bracket (createDnsCache numThreads timeout) freeDnsCache


-- ==================== --
-- Running computations --
-- ==================== --

-- | Run a 'StateT' computation whose state is a freshly created DNS cache.
--
-- The cache gets the given number of resolver threads and timeout, and is
-- freed afterwards via 'withDnsCache'.

withDnsStateT ::
    (Applicative m, HasExceptions m, MonadIO m) =>
    Int -> NominalDiffTime -> StateT a DnsCache m a -> m a
withDnsStateT numThreads timeout comp =
    withDnsCache numThreads timeout (flip evalStateT comp)


-- =================== --
-- DNS mass resolution --
-- =================== --


-- | Look up many domains concurrently through the DNS cache of the current
-- monad and collect the results in a map keyed by domain.

massLookup :: DnsMonad m => MassType -> [Domain] -> m (Map Domain MassResult)
massLookup mtype domains = massLookup_ mtype domains =<< getDnsCache


-- | Perform a mass lookup using the given DNS cache.
--
-- Starts the mass resolver, then waits for exactly one result per entry of
-- the domain list and gathers them into a map keyed by 'massDomain'.  For
-- duplicate domains, the result that arrives last wins.

massLookup_ :: MonadIO m => MassType -> [Domain] -> DnsCache -> m (Map Domain MassResult)
massLookup_ mtype domains dns =
    liftIO $ do
        resultVar <- newEmptyMVar
        massResolver dns mtype domains resultVar

        let gather remaining acc
                | remaining <= 0 = return acc
                | otherwise = do
                    result <- takeMVar resultVar
                    gather (remaining - 1) (M.insert (massDomain result) result acc)

        gather (length domains) M.empty


-- | Look up many domains concurrently through the DNS cache of the current
-- monad, passing each result to the report function as soon as it arrives
-- instead of collecting them.

massLookupReport ::
    DnsMonad m =>
    MassType -> [Domain] -> (MassResult -> m ()) -> m ()
massLookupReport mtype domains rep =
    massLookupReport_ mtype domains rep =<< getDnsCache


-- | Perform a mass lookup using the given DNS cache, calling the report
-- function once for every entry of the domain list, in completion order.

massLookupReport_ ::
    MonadIO m =>
    MassType -> [Domain] -> (MassResult -> m ()) -> DnsCache -> m ()
massLookupReport_ mtype domains rep dns = do
    resultVar <- liftIO $ do
        var <- newEmptyMVar
        massResolver dns mtype domains var
        return var
    -- Exactly one result arrives per entry of the domain list.
    forM_ domains $ \_ ->
        liftIO (takeMVar resultVar) >>= rep


-- | Mass resolver.
--
-- Forks one thread per domain.  Each thread performs the lookups selected
-- by the 'MassType' and puts exactly one 'MassResult' for its domain into
-- the result variable, so the caller can collect one result per input
-- domain.  Results arrive in completion order, not input order.  Record
-- types that were not requested, or whose lookup failed, are reported as
-- empty lists.

massResolver :: DnsCache -> MassType -> [Domain] -> MVar MassResult -> IO ()
massResolver dns mtype domains resultVar =
    forM_ domains $ \domain ->
        forkIO $ lookupDomain domain >>= putMVar resultVar

    where
        -- Perform the requested lookups for one domain.  Single-type
        -- lookups run directly in the per-domain thread; only 'MassAll'
        -- needs extra threads so that its three lookups can overlap.
        lookupDomain :: Domain -> IO MassResult
        lookupDomain domain =
            case mtype of
              MassA ->
                  fmap (\a -> MassResult domain a [] [])
                       (orEmpty (resolveA_ domain dns))
              MassAAAA ->
                  fmap (\aaaa -> MassResult domain [] aaaa [])
                       (orEmpty (resolveAAAA_ domain dns))
              MassMX ->
                  fmap (\mx -> MassResult domain [] [] mx)
                       (orEmpty (resolveMX_ domain dns))
              MassAll -> do
                  aVar <- spawn (resolveA_ domain dns)
                  aaaaVar <- spawn (resolveAAAA_ domain dns)
                  mxVar <- spawn (resolveMX_ domain dns)
                  a <- takeMVar aVar
                  aaaa <- takeMVar aaaaVar
                  mx <- takeMVar mxVar
                  return (MassResult domain a aaaa mx)

        -- A failed lookup ('Nothing') counts as "no records".
        orEmpty :: IO (Maybe [a]) -> IO [a]
        orEmpty = fmap (maybe [] id)

        -- Run a lookup in its own thread; the result lands in the returned
        -- variable.
        spawn :: IO (Maybe [a]) -> IO (MVar [a])
        spawn act = do
            var <- newEmptyMVar
            _ <- forkIO (orEmpty act >>= putMVar var)
            return var