{-# LANGUAGE OverloadedStrings #-}
module Cerberus.Lib
    (serveProxy, ProxyOpts(..), CacheBackend(..)) where

import           Control.Arrow                     ((>>>))
import           Data.ByteString.Char8             (ByteString, pack)
import           Network.HTTP.Client               (newManager)
import           Network.HTTP.Client.TLS
import           Network.HTTP.ReverseProxy
import           Network.HTTP.Types
import           Network.Wai                       (Application, Middleware)
import           Network.Wai.Handler.Warp          (run)
import qualified Network.Wai.Internal              as I
import           Network.Wai.Middleware.Cache      (cache)
import qualified Network.Wai.Middleware.LRUCache   as LRUCache
import           Network.Wai.Middleware.RedisCache (ConnectInfo (..), PortID(..),
                                                    defaultConnectInfo)
import qualified Network.Wai.Middleware.RedisCache as RedisCache
import           Network.Wai.Middleware.Throttle
import           System.Log.Logger
import           System.Metrics.Counter            (Counter)
import qualified System.Metrics.Counter            as Counter

-- | Which store backs the response cache placed in front of the upstream.
--
-- 'serveProxy' dispatches on this value to build the cache middleware.
data CacheBackend
  = LRU    -- ^ Built with @LRUCache.newCacheBackend@, bounded by '_cacheSize'.
  | Redis  -- ^ Built with @RedisCache.newCacheBackend@ from the @_redis*@ options.
  deriving (Eq, Show)

-- | Everything 'serveProxy' needs to start the proxy: metrics counters, the
-- listening port, the upstream location, throttling and cache configuration.
data ProxyOpts = ProxyOpts
  { _req           :: Counter
    -- ^ Incremented by 'countRequests' and by the cache backend's first
    -- counting callback (see 'serveProxy').
  , _reqF          :: Counter
    -- ^ Incremented in 'proxyToApp' for every request routed upstream.
  , _reqC          :: Counter
    -- ^ Incremented by the cache backend's second counting callback;
    -- presumably counts cache hits (TODO confirm against the backend).
  , _cacheSize     :: Integer
    -- ^ Capacity handed (as @Just@) to the LRU backend; unused for Redis.
  , _localPort     :: Int
    -- ^ Port the Warp server listens on.
  , _url           :: String
    -- ^ Upstream host. It is used both as the forwarded @Host@ header and as
    -- the 'ProxyDest' host, so it appears to be a bare host name rather than
    -- a full URL (NOTE(review): confirm).
  , _portN         :: Int
    -- ^ Upstream port; requests are forwarded with 'WPRModifiedRequestSecure'.
  , _maxNbReqBySeq :: Integer
    -- ^ Used as the throttle burst size (see 'throttleSettings').
  , _cacheBackend  :: CacheBackend
    -- ^ Which cache implementation 'serveProxy' builds.
  , _redisHost     :: String
    -- ^ Redis host (Redis backend only).
  , _redisPort     :: Int
    -- ^ Redis port (Redis backend only).
  , _redisPass     :: Maybe String
    -- ^ Optional Redis password. It is packed with "Data.ByteString.Char8",
    -- so non-Latin-1 characters would be truncated (NOTE(review)).
  }

-- | Put @new@ at the front of the header list after dropping every existing
-- header whose name equals @new@'s name.
--
-- Names are compared with the 'Eq' instance of 'HeaderName' (a
-- @CI ByteString@ in http-types), so the match is case-insensitive. The
-- relative order of the surviving headers is preserved.
updateHeaders :: Header -> [Header] -> [Header]
updateHeaders new headers =
  new : [ hdr | hdr <- headers, fst hdr /= fst new ]

-- | Retarget a request at the upstream host @url@.
--
-- This sets 'I.requestHeaderHost' and replaces any existing @Host@ header
-- with @url@ (see 'updateHeaders'). All other request fields are untouched.
--
-- NOTE(review): the @Host@ value carries no port. If the upstream listens on a
-- non-default port, it may expect @host:port@; confirm.
proxyRequest :: ByteString -> I.Request -> I.Request
proxyRequest url req =
  req { I.requestHeaderHost = Just url
      , I.requestHeaders    = updateHeaders hostHeader (I.requestHeaders req)
      }
  where
    hostHeader :: Header
    hostHeader = ("Host", url)

-- | Build a WAI application that forwards every request to @url@ on port
-- @portN@ via 'WPRModifiedRequestSecure'. It uses a fresh connection manager
-- created from 'tlsManagerSettings'.
--
-- For each incoming request it:
--
-- * logs the request at DEBUG level under @"Cerberus.Lib"@,
-- * increments @reqF@, and
-- * rewrites the Host information with 'proxyRequest'.
--
-- Exceptions while proxying are handled by 'defaultOnExc'.
--
-- NOTE(review): the debug log prints the whole request, including headers, so
-- credentials may end up in logs when DEBUG is enabled.
proxyToApp :: Counter -> String -> Int -> IO Application
proxyToApp reqF url portN = do
    manager <- newManager tlsManagerSettings
    let upstreamHost = pack url
        destination  = ProxyDest upstreamHost portN
        route req    = do
          debugM "Cerberus.Lib" (show req)
          Counter.inc reqF
          return $ WPRModifiedRequestSecure (proxyRequest upstreamHost req)
                                            destination
    return (waiProxyTo route defaultOnExc manager)

-- | 'defaultThrottleSettings' with 'throttleRate' fixed at 1 and
-- 'throttleBurst' set to the argument. Every other field keeps its default.
--
-- NOTE(review): the argument (fed from '_maxNbReqBySeq') only sizes the burst;
-- the sustained rate stays at 1. If the intent is "N requests per second",
-- 'throttleRate' probably needs to be N as well. Confirm against the
-- wai-middleware-throttle semantics.
throttleSettings :: Integer -> ThrottleSettings
throttleSettings burst = settings
  where
    settings = defaultThrottleSettings
      { throttleBurst = burst
      , throttleRate  = 1
      }

-- | Middleware that increments @counter@ for every request that reaches it,
-- then passes the request on to the wrapped application unchanged.
countRequests :: Counter -> Middleware
countRequests counter app = \request respond ->
  Counter.inc counter >> app request respond

-- | Start Warp on '_localPort', serving the upstream proxy behind the given
-- cache middleware.
--
-- The stack, outermost first, is
-- @cacheM (throttle ... (countRequests ... upstream))@. The upstream
-- application ('proxyToApp') and the throttler are initialised, in that
-- order, before Warp starts.
--
-- NOTE(review): because the cache is outermost, any request the cache answers
-- on its own presumably bypasses throttling and the '_req' counter. Confirm
-- this is intended.
_serveProxy :: Middleware -- ^ Cache middleware, applied outermost
            -> ProxyOpts  -- ^ Proxy options
            -> IO ()
_serveProxy cacheM opts = do
  upstream  <- proxyToApp (_reqF opts) (_url opts) (_portN opts)
  throttler <- initThrottler
  let stack = countRequests (_req opts)
          >>> throttle (throttleSettings (_maxNbReqBySeq opts)) throttler
          >>> cacheM
  run (_localPort opts) (stack upstream)

-- | Entry point: build the cache middleware selected by '_cacheBackend' and
-- run the proxy with it via '_serveProxy'.
--
-- Both backends are created with:
--
-- * a predicate that always answers 'True', and
-- * two callbacks that increment '_req' and '_reqC' respectively.
--
-- The LRU backend is bounded by '_cacheSize'. The Redis backend connects using
-- '_redisHost', '_redisPort' and '_redisPass'.
serveProxy :: ProxyOpts -> IO ()
serveProxy opts = do
  -- Match on the constructor instead of guarding with '=='. The compiler's
  -- exhaustiveness check (-Wincomplete-patterns) then covers any backend added
  -- to 'CacheBackend', and the unreachable runtime 'error' branch goes away.
  cacheM <- case _cacheBackend opts of
    LRU ->
      cache <$> LRUCache.newCacheBackend (Just (_cacheSize opts))
                                         (\_ _ -> return True)
                                         (\_ _ -> Counter.inc (_req opts))
                                         (\_ _ -> Counter.inc (_reqC opts))
    Redis ->
      cache <$> RedisCache.newCacheBackend redisInfo
                                           (\_ _ -> return True)
                                           (\_ _ -> Counter.inc (_req opts))
                                           (\_ _ -> Counter.inc (_reqC opts))
  _serveProxy cacheM opts
  where
    -- Redis connection settings. Lazily bound, so they are only built when
    -- the Redis backend is selected.
    redisInfo = defaultConnectInfo
      { connectHost = _redisHost opts
      , connectPort = PortNumber (fromIntegral (_redisPort opts))
      , connectAuth = fmap pack (_redisPass opts)
      }