{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.DNS.Cache (
    DNSCacheConf(..)
  , DNSCache
  , withDNSCache
  
  , lookup
  , lookupCache
  
  , Result(..)
  , resolve
  , resolveCache
  
  , wait
  ) where
import Control.Applicative ((<$>))
import Control.Concurrent.Lifted (threadDelay, fork, killThread)
import Control.Concurrent.Async (async, waitAnyCancel)
import Control.Exception.Lifted (bracket)
import Control.Monad (forever)
import Control.Monad.IO.Class
import Control.Monad.Trans.Control
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Short as B
import Data.IP (toHostAddress)
import Data.Time (getCurrentTime, addUTCTime, NominalDiffTime)
import Network.DNS hiding (lookup)
import Network.DNS.Cache.Cache
import qualified Network.DNS.Cache.Sync as S
import Network.DNS.Cache.Types
import Network.DNS.Cache.Utils
import Network.DNS.Cache.Value
import Prelude hiding (lookup)
-- | Configuration passed to 'withDNSCache'.
data DNSCacheConf = DNSCacheConf {
  -- | Resolver configurations. One resolver seed is made per entry. When
  --   more than one is given, an uncached query is sent to all of them
  --   concurrently and the first query to complete (successfully or not)
  --   determines the result.
    resolvConfs    :: [ResolvConf]
  -- | Limit on outgoing queries in flight at once; it is passed as the
  --   limit to 'S.waitIncrease' before each query is sent.
  , maxConcurrency :: Int
  -- | Lower bound on the lifetime of a positive cache entry. The TTL of
  --   the first A record is clamped to @[minTTL, maxTTL]@; if
  --   @minTTL > maxTTL@, 'minTTL' wins.
  , minTTL         :: NominalDiffTime
  -- | Upper bound on the lifetime of a positive cache entry (see 'minTTL').
  , maxTTL         :: NominalDiffTime
  -- | Lifetime of a negative cache entry, i.e. a query that returned an
  --   error or returned no A records.
  , negativeTTL    :: NominalDiffTime
  }
-- | Handle created by 'withDNSCache'. Its background pruning thread only
--   runs while the action passed to 'withDNSCache' is running.
data DNSCache = DNSCache {
    cacheSeeds      :: [ResolvSeed]     -- ^ one seed per element of 'resolvConfs'
  , cacheNofServers :: !Int             -- ^ length of 'cacheSeeds'; selects the single-server path in 'resolv'
  , cacheRef        :: CacheRef         -- ^ positive and negative entries, keyed by the domain as a short bytestring
  , cacheActiveRef  :: S.ActiveRef      -- ^ domains with a query in flight, mapped to the var waiting threads listen on
  , cacheConcVar    :: S.ConcVar        -- ^ concurrency state used by 'waitIncrease', 'decrease' and 'wait'
  , cacheConcLimit  :: Int              -- ^ copy of 'maxConcurrency'
  , cacheMinTTL     :: NominalDiffTime  -- ^ copy of 'minTTL'
  , cacheMaxTTL     :: NominalDiffTime  -- ^ copy of 'maxTTL'
  , cacheNegTTL     :: NominalDiffTime  -- ^ copy of 'negativeTTL'
  }
-- | Build a 'DNSCache' from the configuration and run the action with it.
--   A background thread running 'prune' is forked first and killed when
--   the action finishes, whether it returns normally or throws.
withDNSCache :: (MonadBaseControl IO m, MonadIO m)
             => DNSCacheConf -> (DNSCache -> m a) -> m a
withDNSCache conf func = do
    seeds     <- liftIO $ mapM makeResolvSeed (resolvConfs conf)
    cacheref  <- liftIO newCacheRef
    activeref <- liftIO S.newActiveRef
    concvar   <- liftIO S.newConcVar
    let cache = DNSCache {
              cacheSeeds      = seeds
            , cacheNofServers = length seeds
            , cacheRef        = cacheref
            , cacheActiveRef  = activeref
            , cacheConcVar    = concvar
            , cacheConcLimit  = maxConcurrency conf
            , cacheMinTTL     = minTTL conf
            , cacheMaxTTL     = maxTTL conf
            , cacheNegTTL     = negativeTTL conf
            }
        startPruner = fork (liftIO (prune cacheref))
    bracket startPruner killThread (\_ -> func cache)
-- | Look a domain up in the cache. Returns the cache key (the domain
--   converted to a short bytestring) together with the cached priority
--   and entry, if any. Both the key and the lookup result are forced to
--   WHNF before being returned.
lookupPSQ :: DNSCache -> Domain -> IO (Key, Maybe (Prio, Entry))
lookupPSQ cache dom = key `seq` do
    found <- lookupCacheRef key (cacheRef cache)
    found `seq` return (key, found)
  where
    key = B.toShort dom
-- | Cache-only variant of 'lookup': answers from 'resolveCache' and never
--   sends a query. Yields 'Nothing' when the domain is not cached, and
--   otherwise whatever 'fromEither' makes of the cached result.
lookupCache :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookupCache cache dom = (>>= fromEither) <$> resolveCache cache dom
-- | Resolve a domain with 'resolve' (querying when it is not cached) and
--   reduce the outcome to an optional address via 'fromEither'.
lookup :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookup cache = fmap fromEither . resolve cache
-- | Like 'resolve', but answers only from the cache and never sends a query.
--
-- * A domain accepted by 'isIPAddr' is parsed as an IPv4 address and
--   returned as 'Numeric'.
-- * A cached positive entry yields 'Hit' with the address obtained by
--   'rotate' on the stored value.
-- * A cached negative entry yields its 'DNSError'.
-- * An uncached domain yields 'Nothing'.
resolveCache :: DNSCache -> Domain -> IO (Maybe (Either DNSError Result))
resolveCache cache dom
  | isIPAddr dom = return . Just . Right . Numeric $ parseIPv4 dom
  | otherwise    = do
      (_, cached) <- lookupPSQ cache dom
      case snd <$> cached of
          Nothing        -> return Nothing
          Just (Left e)  -> return (Just (Left e))
          Just (Right v) -> Just . Right . Hit <$> rotate v
  where
    parseIPv4 = toHostAddress . read . BS.unpack
-- | Resolve a domain to an IPv4 address, consulting the cache first.
--
-- * A domain accepted by 'isIPAddr' is parsed directly and returned as
--   'Numeric'; no query is sent.
-- * A cached positive entry yields 'Hit' (via 'rotate'); a cached
--   negative entry yields its 'DNSError'.
-- * If another thread already has a query in flight for the same domain,
--   this thread waits for that query's result ('S.listen') instead of
--   sending its own.
-- * Otherwise a query is sent. The outcome is stored in the cache: an
--   error or an empty answer becomes a negative entry with
--   'UnexpectedRDATA' used for the empty case, and addresses become a
--   positive entry. The outcome is handed to waiting threads with
--   'Resolved' turned into 'Hit' and returned to the caller as is.
--
-- The in-flight marker is registered and removed with 'bracket'. If the
-- query throws (e.g. an exception re-raised by 'waitAnyCancel'), the
-- marker is still removed. Later lookups of the domain therefore retry
-- rather than listening on a var that is never told.
-- NOTE(review): threads already listening on a query that throws are
-- not told anything; confirm how 'S.listen' behaves in that case.
resolve :: DNSCache -> Domain -> IO (Either DNSError Result)
resolve _     dom
  | isIPAddr dom            = return $ Right $ Numeric $ tov4 dom
  where
    tov4 = toHostAddress . read . BS.unpack
resolve cache dom = do
    (key,mx) <- lookupPSQ cache dom
    case mx of
        Just (_, Left e)  -> return (Left e)
        Just (_, Right v) -> Right . Hit <$> rotate v
        Nothing -> do
            -- Join a query already in flight for this domain, if any.
            ma <- S.lookupActiveRef key activeref
            case ma of
                Just avar -> S.listen avar
                Nothing   -> do
                    avar <- S.newActiveVar
                    -- Remove the in-flight marker before telling listeners,
                    -- and remove it even when the query throws.
                    !res <- bracket
                        (S.insertActiveRef key avar activeref)
                        (\_ -> S.deleteActiveRef key activeref)
                        (\_ -> queryAndStore key)
                    S.tell avar (toHit res)
                    return res
  where
    activeref = cacheActiveRef cache
    -- Send the query and record its outcome in the cache.
    queryAndStore k = do
        x <- sendQuery cache dom
        case x of
            Left  err   -> insertNegative cache k err
            Right []    -> insertNegative cache k UnexpectedRDATA
            Right addrs -> insertPositive cache k addrs
    -- Listeners did not trigger the query themselves, so they see a 'Hit'.
    toHit (Right (Resolved addr)) = Right (Hit addr)
    toHit x                       = x
-- | Store a positive answer under the given key and return 'Resolved'
--   with the first address. All addresses are stored via 'positiveEntry'.
--   The entry expires after the first record's TTL, clamped to
--   @[cacheMinTTL, cacheMaxTTL]@ (the minimum wins if the bounds cross).
--   Must not be called with an empty list; 'resolve' handles that case.
insertPositive :: DNSCache -> Key -> [(HostAddress, TTL)]
               -> IO (Either DNSError Result)
insertPositive cache key addrs = case addrs of
    []              -> error "insertPositive"
    (addr, ttl) : _ -> do
        let !lifeTime = clampTTL (fromIntegral ttl)
        !ent <- positiveEntry (fst <$> addrs)
        now <- getCurrentTime
        let !expiry = addUTCTime lifeTime now
        insertCacheRef key expiry ent (cacheRef cache)
        return $! Right (Resolved addr)
  where
    clampTTL t = max (cacheMinTTL cache) (min (cacheMaxTTL cache) t)
-- | Store a negative entry (the given error) under the key, expiring after
--   'cacheNegTTL', and return the error as the result.
insertNegative :: DNSCache -> Key -> DNSError -> IO (Either DNSError Result)
insertNegative cache key err = do
    now <- getCurrentTime
    let !expiry = cacheNegTTL cache `addUTCTime` now
    insertCacheRef key expiry (Left err) (cacheRef cache)
    return (Left err)
-- | Run 'concResolv' while holding a concurrency slot: 'waitIncrease'
--   before the query, 'decrease' afterwards (also when the query throws).
sendQuery :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
sendQuery cache dom =
    bracket (waitIncrease cache)
            (const (decrease cache))
            (const (concResolv cache dom))
-- | Take a query slot on the cache's concurrency variable, using
--   'cacheConcLimit' as the limit.
--   NOTE(review): presumably blocks while the limit is reached; see
--   'S.waitIncrease'.
waitIncrease :: DNSCache -> IO ()
waitIncrease cache = S.waitIncrease (cacheConcVar cache) (cacheConcLimit cache)
-- | Release a query slot previously taken with 'waitIncrease'.
decrease :: DNSCache -> IO ()
decrease = S.decrease . cacheConcVar
-- | Query the configured resolvers (via 'resolv') for the domain's A
--   records. Returns the address and TTL of every A record in the answer
--   section, or the error reported by the query or by 'fromDNSFormat'.
--
-- Only records that are both typed 'A' and carry 'RD_A' data are kept.
-- The earlier version used a partial helper that embedded 'error' calls
-- in the returned list. Such an error would only fire once the list was
-- forced, possibly after it had been stored in the cache.
concResolv :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
concResolv cache dom = withResolvers seeds $ \resolvers -> do
    eans <- resolv n resolvers dom
    return $ case eans of
        Left  err -> Left err
        Right ans -> fromDNSFormat ans addressesWithTTL
  where
    n = cacheNofServers cache
    seeds = cacheSeeds cache
    addressesWithTTL ans =
        [ (toHostAddress ip, rrttl r)
        | r <- answer ans
        , rrtype r == A
        , RD_A ip <- [rdata r]
        ]
-- | Send an A query for the domain.
--
-- * With @n == 1@, only the first resolver is asked.
-- * Otherwise every resolver is asked concurrently. The first query to
--   complete determines the result, whether it succeeded or returned an
--   error, and the others are cancelled by 'waitAnyCancel'. An exception
--   from that first query is re-raised.
-- * An empty resolver list (e.g. an empty 'resolvConfs') is a
--   configuration error. Previously this fell through to
--   @waitAnyCancel []@, which can never complete.
resolv :: Int -> [Resolver] -> Domain -> IO (Either DNSError DNSFormat)
resolv _ [] _ = error "resolv: no resolvers configured (resolvConfs is empty)"
resolv 1 (r:_) dom = lookupRaw r dom A
resolv _ resolvers dom = do
    asyncs <- mapM async actions
    snd <$> waitAnyCancel asyncs
  where
    actions = map (\res -> lookupRaw res dom A) resolvers
-- | Call 'S.wait' on the cache's concurrency variable with the given
--   predicate.
--   NOTE(review): presumably blocks until the predicate holds for the
--   number of queries in flight; see 'S.wait'.
wait :: DNSCache -> (Int -> Bool) -> IO ()
wait cache = S.wait (cacheConcVar cache)
-- | Endless background loop. Every 10 seconds it calls 'pruneCacheRef'
--   with the current time (presumably dropping entries that have expired).
--   It never returns; 'withDNSCache' kills the thread running it.
prune :: CacheRef -> IO ()
prune cacheref = forever $ do
    threadDelay tenSeconds
    now <- getCurrentTime
    pruneCacheRef now cacheref
  where
    -- threadDelay takes microseconds.
    tenSeconds :: Int
    tenSeconds = 10 * 1000 * 1000