{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveFunctor #-}

module Hans.Nat.State (
    NatState(), HasNatState(..),
    newNatState,
    Flow(..),
    Session(..),

    otherSide,

    PortForward(..),

    -- ** Rules
    addUdpPortForward, removeUdpPortForward,
    addTcpPortForward, removeTcpPortForward,

    -- ** Queries
    udpForwardingActive, addUdpSession, shouldForwardUdp,
    tcpForwardingActive, addTcpSession, shouldForwardTcp,
  ) where

import           Hans.Addr (Addr,isWildcardAddr)
import           Hans.Config (Config(..))
import           Hans.Lens (Getting,view)
import           Hans.Network.Types (RouteInfo(..))
import           Hans.Tcp.Packet (TcpPort)
import           Hans.Threads (forkNamed)
import           Hans.Udp.Packet (UdpPort)

import           Control.Concurrent (ThreadId,threadDelay)
import           Control.Monad (forever)
import           Data.HashPSQ as Q
import           Data.Hashable (Hashable)
import           Data.IORef (IORef,newIORef,readIORef,atomicModifyIORef')
import           Data.List (find)
import           Data.Time.Clock
                     (UTCTime,getCurrentTime,NominalDiffTime,addUTCTime)
import           Data.Word (Word16)
import           GHC.Generics (Generic)


-- State -----------------------------------------------------------------------

-- | One direction of a connection as seen from this network stack: a local
-- endpoint paired with a remote endpoint.
--
-- NOTE: as TcpPort and UdpPort are both type aliases to Word16, Flow isn't
-- parameterized on the port type. The @local@ parameter is the type of the
-- local address; this module uses both @Flow Addr@ and
-- @Flow (RouteInfo Addr)@.
data Flow local = Flow { flowLocal      :: !local
                         -- ^ Local address (or routing information for it)
                       , flowLocalPort  :: !Word16
                         -- ^ Local port
                       , flowRemote     :: !Addr
                         -- ^ Remote address
                       , flowRemotePort :: !Word16
                         -- ^ Remote port
                       } deriving (Functor,Eq,Ord,Generic,Show)

-- | Hashing via the 'Generic' representation. The type parameter is the
-- local-address type, so it is named @local@ to match the declaration above.
instance Hashable local => Hashable (Flow local)


-- | Mutable state of the NAT layer. It holds a session table and a list of
-- port-forwarding rules for each of TCP and UDP, plus the id of the reaper
-- thread.
data NatState =
  NatState { natTcpTable_ :: !NatTable
             -- ^ Active TCP flows

           , natTcpRules_ :: !(IORef [PortForward])
             -- ^ Ports that have been forwarded in the TCP layer

           , natUdpTable_ :: !NatTable
             -- ^ Active UDP flows

           , natUdpRules_ :: !(IORef [PortForward])
             -- ^ Ports that have been forwarded in the UDP layer

           , natReaper_   :: !ThreadId
             -- ^ Id of the background thread that periodically expires old
             -- entries from both NAT tables. 'newNatState' starts it
             -- unconditionally, not only while flows are active.
           }

-- | Types that contain a 'NatState', which can be extracted with 'view'.
class HasNatState state where
  natState :: Getting r state NatState

-- | A 'NatState' trivially contains itself.
instance HasNatState NatState where
  natState f = f

-- | A port-forwarding rule: a local source address and port, and the
-- destination address and port to forward to. A source address for which
-- 'isWildcardAddr' holds matches any local address (see 'ruleApplies').
data PortForward = PortForward
  { pfSourceAddr :: !Addr
    -- ^ Local address to listen on
  , pfSourcePort :: !Word16
    -- ^ The port on this network stack to forward
  , pfDestAddr   :: !Addr
    -- ^ Destination machine to forward to
  , pfDestPort   :: !Word16
    -- ^ Destination port to forward to
  }


-- | Allocate empty NAT tables and empty rule lists for TCP and UDP, then fork
-- the thread that reaps stale entries from both tables.
newNatState :: Config -> IO NatState
newNatState cfg =
  do tcpTable <- newNatTable cfg
     tcpRules <- newIORef []
     udpTable <- newNatTable cfg
     udpRules <- newIORef []
     reaperId <- forkNamed "Nat.reaper" (reaper tcpTable udpTable)
     return NatState { natTcpTable_ = tcpTable
                     , natTcpRules_ = tcpRules
                     , natUdpTable_ = udpTable
                     , natUdpRules_ = udpRules
                     , natReaper_   = reaperId
                     }


-- Nat Tables ------------------------------------------------------------------

-- | A NAT session made of two flows. Each side stores routing information
-- for its local address; 'otherSide' maps one side to the other.
data Session = Session { sessLeft  :: !(Flow (RouteInfo Addr))
                       , sessRight :: !(Flow (RouteInfo Addr))
                       }

-- | Gives back the other end of the session. If @flow@ has the same remote
-- address and port as 'sessLeft', the result is 'sessRight'; otherwise it is
-- 'sessLeft'. Only the remote endpoint of @flow@ is examined.
otherSide :: Flow Addr -> Session -> Flow (RouteInfo Addr)
otherSide flow Session { .. }
  | matchesLeft = sessRight
  | otherwise   = sessLeft
  where
  matchesLeft = flowRemote flow     == flowRemote sessLeft
             && flowRemotePort flow == flowRemotePort sessLeft

-- | The two table keys for a session: each side's flow, with its local
-- routing information reduced to the source address ('riSource').
sessionFlows :: Session -> (Flow Addr, Flow Addr)
sessionFlows Session { sessLeft = left, sessRight = right } =
  (riSource <$> left, riSource <$> right)


-- | Contents of a NAT table. Every session is stored twice, once under each
-- of the keys produced by 'sessionFlows'. The priority is the time at which
-- the session was last inserted or touched.
type Sessions = Q.HashPSQ (Flow Addr) UTCTime Session

-- | Store a session under both of its keys, with priority @age@.
addSession :: UTCTime -> Session -> Sessions -> Sessions
addSession age sess =
  Q.insert leftKey age sess . Q.insert rightKey age sess
  where
  (leftKey, rightKey) = sessionFlows sess

-- | Drop the entry with the smallest priority (the least recently touched)
-- together with the entry for the other side of its session. An empty table
-- is returned unchanged.
removeOldest :: Sessions -> Sessions
removeOldest q = maybe q dropBoth (Q.minView q)
  where
  dropBoth (oldestKey, _, sess, rest) =
    Q.delete (riSource <$> otherSide oldestKey sess) rest

-- | Remove the entry for @flow@ and the entry for the other side of its
-- session. Returns the removed session together with the remaining table,
-- or 'Nothing' when @flow@ has no entry.
removeSession :: Flow Addr -> Sessions -> Maybe (Session,Sessions)
removeSession flow q =
  do (_, sess, rest) <- Q.deleteView flow q
     let partner = riSource <$> otherSide flow sess
     return (sess, Q.delete partner rest)


-- | A NAT table: its configuration and a mutable reference to its sessions.
data NatTable = NatTable
  { natConfig :: Config
    -- ^ Configuration; 'insertNatTable' consults 'cfgNatMaxEntries'
  , natTable  :: !(IORef Sessions)
    -- ^ Current contents of the table
  }

-- | Create an empty NAT table that uses the given configuration.
newNatTable :: Config -> IO NatTable
newNatTable cfg = NatTable cfg <$> newIORef Q.empty

-- | Insert a session into the NAT table, stamped with the current time.
-- After insertion, if the table holds more than 'cfgNatMaxEntries' entries,
-- the oldest entry and the other side of its session are evicted.
--
-- NOTE(review): the psqueues documentation lists 'Q.size' on a HashPSQ as
-- O(n), so every insertion walks the whole table. Tracking the size
-- separately would avoid this.
insertNatTable :: Session -> NatTable -> IO ()
insertNatTable sess NatTable { .. } =
  do now <- getCurrentTime
     let limit = cfgNatMaxEntries natConfig
         trim q | Q.size q > limit = removeOldest q
                | otherwise        = q
     atomicModifyIORef' natTable (\ q -> (trim (addSession now sess q), ()))

-- | Remove every entry last touched more than four minutes before @now@.
-- The entry for the other side of each such session is removed as well.
expireEntries :: UTCTime -> NatTable -> IO ()
expireEntries now NatTable { .. } =
  atomicModifyIORef' natTable go
  where
  -- entries whose priority is older than this instant have expired
  cutoff = addUTCTime (negate fourMinutes) now

  -- The queue yields entries oldest first, so the walk stops at the first
  -- entry that is still fresh.
  go q =
    case Q.minView q of
      Just (k,p,a,q')
        | p < cutoff -> go (Q.delete (fmap riSource (otherSide k a)) q')
        | otherwise  -> (q, ())

      -- minView only fails on an empty queue, so q is already empty here
      Nothing -> (q, ())

-- | Look up the session for a flow. If one is found, both of its entries are
-- re-inserted with the current time, so the session counts as recently used.
lookupNatTable :: Flow Addr -> NatTable -> IO (Maybe Session)
lookupNatTable key NatTable { .. } =
  do now <- getCurrentTime
     atomicModifyIORef' natTable (touch now)
  where
  touch stamp q =
    case removeSession key q of
      Nothing           -> (q, Nothing)
      Just (sess, rest) -> (addSession stamp sess rest, Just sess)


-- Table Reaping ---------------------------------------------------------------

-- | Loop forever. Every two minutes, expire old entries from the TCP table
-- and then from the UDP table, using a single timestamp for both.
reaper :: NatTable -> NatTable -> IO ()
reaper tcp udp = forever (threadDelay twoMinutes >> sweep)
  where
  twoMinutes = 2 * 60 * 1000000 -- microseconds, as threadDelay expects

  sweep =
    do now <- getCurrentTime
       mapM_ (expireEntries now) [tcp, udp]

-- | Age beyond which 'expireEntries' discards a NAT table entry: 240 seconds.
fourMinutes :: NominalDiffTime
fourMinutes = 240


-- Rules -----------------------------------------------------------------------

-- | Register a TCP port-forwarding rule. The rule is prepended, so
-- 'shouldForwardTcp' (which takes the first match) tries it before older
-- rules.
addTcpPortForward :: HasNatState state => state -> PortForward -> IO ()
addTcpPortForward state rule =
  atomicModifyIORef' (natTcpRules_ (view natState state)) prepend
  where
  prepend rules = (rule : rules, ())

-- | Remove port forwarding for TCP based on source address and port number.
-- Every TCP rule whose 'pfSourceAddr' and 'pfSourcePort' both match is
-- dropped. The remaining rules keep their original order.
removeTcpPortForward :: HasNatState state => state -> Addr -> TcpPort -> IO ()
removeTcpPortForward state addr port =
  do let NatState { .. } = view natState state
     atomicModifyIORef' natTcpRules_ (\rs -> (filter keepRule rs, ()))
  where
  -- keep a rule unless it matches both the address and the port
  keepRule PortForward { .. } = pfSourceAddr /= addr || pfSourcePort /= port

-- | Register a UDP port-forwarding rule. The rule goes to the front of the
-- list, so 'shouldForwardUdp' (which takes the first match) tries it before
-- older rules.
addUdpPortForward :: HasNatState state => state -> PortForward -> IO ()
addUdpPortForward state rule =
  let rulesRef = natUdpRules_ (view natState state)
   in atomicModifyIORef' rulesRef $ \ existing -> (rule : existing, ())

-- | Remove port forwarding for UDP based on source address and port number.
-- Every UDP rule with exactly this 'pfSourceAddr' and 'pfSourcePort' is
-- dropped; all other rules are kept in order.
removeUdpPortForward :: HasNatState state => state -> Addr -> UdpPort -> IO ()
removeUdpPortForward state addr port =
  atomicModifyIORef' (natUdpRules_ (view natState state)) $ \ rules ->
    ([ r | r <- rules, keep r ], ())
  where
  keep r = pfSourceAddr r /= addr || pfSourcePort r /= port


-- Queries ---------------------------------------------------------------------

-- | Lookup information about an active TCP forwarding session. A successful
-- lookup also refreshes the session's timestamp in the TCP table.
tcpForwardingActive :: HasNatState state
                    => state -> Flow Addr -> IO (Maybe Session)
tcpForwardingActive state key =
  lookupNatTable key (natTcpTable_ (view natState state))


-- | Lookup information about an active UDP forwarding session. A successful
-- lookup also refreshes the session's timestamp in the UDP table.
udpForwardingActive :: HasNatState state
                    => state -> Flow Addr -> IO (Maybe Session)
udpForwardingActive state key = lookupNatTable key udpTable
  where
  udpTable = natUdpTable_ (view natState state)


-- | Insert a TCP forwarding session into the NAT state, stamped with the
-- current time.
addTcpSession :: HasNatState state => state -> Session -> IO ()
addTcpSession state sess =
  insertNatTable sess (natTcpTable_ (view natState state))


-- | Insert a UDP forwarding session into the NAT state, stamped with the
-- current time.
addUdpSession :: HasNatState state => state -> Session -> IO ()
addUdpSession state sess = insertNatTable sess udpTable
  where
  udpTable = natUdpTable_ (view natState state)

-- | Does a forwarding rule cover this flow's local endpoint? Both conditions
-- must hold:
--
-- * the local port equals the rule's source port;
-- * the rule's source address equals the flow's local address, or is a
--   wildcard.
ruleApplies :: Flow Addr -> PortForward -> Bool
ruleApplies flow rule = portMatches && addrMatches
  where
  portMatches = flowLocalPort flow == pfSourcePort rule
  addrMatches = flowLocal flow == pfSourceAddr rule
             || isWildcardAddr (pfSourceAddr rule)


-- | Returns the forwarding rule to use if this TCP connection should be
-- forwarded: the first registered TCP rule for which 'ruleApplies' holds.
-- The 'Maybe' result is forced before it is returned.
shouldForwardTcp :: HasNatState state
                 => state -> Flow Addr -> IO (Maybe PortForward)
shouldForwardTcp state flow =
  do rules <- readIORef (natTcpRules_ (view natState state))
     let match = find (ruleApplies flow) rules
     match `seq` return match


-- | Returns the forwarding rule to use if this UDP session should be
-- forwarded: the first registered UDP rule for which 'ruleApplies' holds.
-- The 'Maybe' result is forced before it is returned.
shouldForwardUdp :: HasNatState state
                 => state -> Flow Addr -> IO (Maybe PortForward)
shouldForwardUdp state flow =
  readIORef (natUdpRules_ (view natState state)) >>= \ rules ->
    return $! find (ruleApplies flow) rules