module Hans.Nat.State (
NatState(), HasNatState(..),
newNatState,
Flow(..),
Session(..),
otherSide,
PortForward(..),
addUdpPortForward, removeUdpPortForward,
addTcpPortForward, removeTcpPortForward,
udpForwardingActive, addUdpSession, shouldForwardUdp,
tcpForwardingActive, addTcpSession, shouldForwardTcp,
) where
import Hans.Addr (Addr,isWildcardAddr)
import Hans.Config (Config(..))
import Hans.Lens (Getting,view)
import Hans.Network.Types (RouteInfo(..))
import Hans.Tcp.Packet (TcpPort)
import Hans.Threads (forkNamed)
import Hans.Udp.Packet (UdpPort)
import Control.Concurrent (ThreadId,threadDelay)
import Control.Monad (forever)
import Data.HashPSQ as Q
import Data.Hashable (Hashable)
import Data.IORef (IORef,newIORef,readIORef,atomicModifyIORef')
import Data.List (find)
import Data.Time.Clock
(UTCTime,getCurrentTime,NominalDiffTime,addUTCTime)
import Data.Word (Word16)
import GHC.Generics (Generic)
-- | One direction of a transport-level conversation: a local endpoint and
-- port, and a remote address and port.  The local endpoint is a type
-- parameter; this module uses both @Flow Addr@ (table keys) and
-- @Flow (RouteInfo Addr)@ (session sides), converting with 'fmap'.
data Flow local = Flow
  { flowLocal      :: !local   -- ^ local endpoint
  , flowLocalPort  :: !Word16  -- ^ local port
  , flowRemote     :: !Addr    -- ^ remote address
  , flowRemotePort :: !Word16  -- ^ remote port
  } deriving (Eq,Ord,Show,Functor,Generic)

-- | Hashing via the 'Generic' default, so flows can key a 'Q.HashPSQ'.
instance Hashable local => Hashable (Flow local)
-- | The complete NAT state: for each of TCP and UDP a session table and a
-- mutable list of port-forwarding rules, plus the id of the background
-- thread (forked by 'newNatState') that expires idle sessions.
data NatState = NatState
  { natTcpTable_ :: !NatTable               -- ^ TCP sessions
  , natTcpRules_ :: !(IORef [PortForward])  -- ^ TCP port-forwarding rules
  , natUdpTable_ :: !NatTable               -- ^ UDP sessions
  , natUdpRules_ :: !(IORef [PortForward])  -- ^ UDP port-forwarding rules
  , natReaper_   :: !ThreadId               -- ^ the expiry ("reaper") thread
  }
-- | Types from which the 'NatState' can be read with 'view'.
class HasNatState state where
  -- | Getter focusing on the contained 'NatState'.
  natState :: Getting r state NatState

-- | A 'NatState' is trivially its own NAT state.
instance HasNatState NatState where
  natState = id
-- | A port-forwarding rule.  'ruleApplies' matches a flow when its local port
-- equals 'pfSourcePort' and its local address equals 'pfSourceAddr' (or
-- 'pfSourceAddr' is a wildcard).  The destination fields are not interpreted
-- in this module.
data PortForward = PortForward
  { pfSourceAddr :: !Addr    -- ^ local address to match (may be a wildcard)
  , pfSourcePort :: !Word16  -- ^ local port to match
  , pfDestAddr   :: !Addr    -- ^ forwarding destination address
  , pfDestPort   :: !Word16  -- ^ forwarding destination port
  }
-- | Create a fresh NAT state.  The TCP and UDP tables start empty and share
-- the same configuration, and both rule lists start empty.  A background
-- thread named "Nat.reaper" is forked to expire idle sessions from both
-- tables.
newNatState :: Config -> IO NatState
newNatState cfg =
  do tcpTable  <- newNatTable cfg
     udpTable  <- newNatTable cfg
     tcpRules  <- newIORef []
     udpRules  <- newIORef []
     reaperTid <- forkNamed "Nat.reaper" (reaper tcpTable udpTable)
     return NatState { natTcpTable_ = tcpTable
                     , natTcpRules_ = tcpRules
                     , natUdpTable_ = udpTable
                     , natUdpRules_ = udpRules
                     , natReaper_   = reaperTid
                     }
-- | A NAT session: two linked flows, one for each side of the translation.
-- Each side's local endpoint carries routing information; 'otherSide'
-- maps a flow on one side to the flow on the other.
data Session = Session
  { sessLeft  :: !(Flow (RouteInfo Addr))  -- ^ one side of the session
  , sessRight :: !(Flow (RouteInfo Addr))  -- ^ the other side
  }
-- | Given a flow belonging to a session, return the session's opposite side.
--
-- The flow is treated as the left side when its remote address and port
-- equal those of 'sessLeft'.  Otherwise it is assumed to be the right side.
-- NOTE(review): local endpoints are not compared.  If both sides ever share
-- a remote address/port, this always returns 'sessRight'.
otherSide :: Flow Addr -> Session -> Flow (RouteInfo Addr)
otherSide flow Session { .. }
  | isLeft    = sessRight
  | otherwise = sessLeft
  where
    isLeft = flowRemote flow == flowRemote sessLeft
          && flowRemotePort flow == flowRemotePort sessLeft
-- | The two table keys of a session: each side's flow with its route info
-- reduced to the route's source address.
sessionFlows :: Session -> (Flow Addr, Flow Addr)
sessionFlows sess =
  ( riSource <$> sessLeft sess
  , riSource <$> sessRight sess
  )
-- | The session table.  Each session is stored twice, once under each of its
-- two flows (see 'sessionFlows').  Both entries carry the same priority:
-- the time the session was last added or refreshed, so 'Q.minView' yields
-- the least recently used session.
type Sessions =
  Q.HashPSQ (Flow Addr)  -- key: one side of a session
            UTCTime      -- priority: last-used timestamp
            Session      -- value: the whole session
-- | Insert a session under both of its flows, with priority @age@.
--
-- Any session already stored under either flow is first removed completely,
-- so both of its keys are dropped.  A bare 'Q.insert' would overwrite only
-- the colliding key and leave the stale session's other key behind.  A
-- later 'removeSession' on that orphan would then delete a key owned by the
-- newly inserted session.  Evicting first keeps every key's session pointing
-- back at both of its keys.
addSession :: UTCTime -> Session -> Sessions -> Sessions
addSession age a q =
  let (l,r)       = sessionFlows a
      -- drop whatever session (both halves) currently owns key k, if any
      evict k tbl = maybe tbl snd (removeSession k tbl)
      q'          = evict l (evict r q)
  in Q.insert l age a (Q.insert r age a q')
-- | Drop the session with the oldest timestamp (both of its entries).
-- An empty table is returned unchanged.
removeOldest :: Sessions -> Sessions
removeOldest q = maybe q dropBoth (Q.minView q)
  where
    -- 'Q.minView' already removed the key it returned; remove the partner.
    dropBoth (key,_,sess,rest) = Q.delete (fmap riSource (otherSide key sess)) rest
-- | Remove the session stored under @flow@, deleting both of its entries.
-- Returns the removed session and the remaining table, or 'Nothing' if no
-- entry exists for @flow@.
removeSession :: Flow Addr -> Sessions -> Maybe (Session,Sessions)
removeSession flow q =
  do (_,sess,rest) <- Q.deleteView flow q
     let partner = riSource <$> otherSide flow sess
     return (sess, Q.delete partner rest)
-- | A mutable session table together with the configuration that bounds its
-- size.  Both fields are strict, like every other record field in this
-- module, so no unevaluated configuration thunk is retained.
data NatTable = NatTable
  { natConfig :: !Config           -- ^ read for 'cfgNatMaxEntries' on insert
  , natTable  :: !(IORef Sessions) -- ^ the session table itself
  }
-- | Allocate an empty session table that uses the given configuration.
newNatTable :: Config -> IO NatTable
newNatTable cfg = NatTable cfg <$> newIORef Q.empty
-- | Add a session stamped with the current time.  If the table then holds
-- more than 'cfgNatMaxEntries' entries (each session counts as two), the
-- least recently used session is evicted.  The update is atomic.
insertNatTable :: Session -> NatTable -> IO ()
insertNatTable sess NatTable { .. } =
  do now <- getCurrentTime
     atomicModifyIORef' natTable (\q -> (bound (addSession now sess q), ()))
  where
    bound q
      | Q.size q > cfgNatMaxEntries natConfig = removeOldest q
      | otherwise                             = q
-- | Atomically drop every session whose timestamp is more than 'fourMinutes'
-- before @now@.  Sessions are popped oldest-first until one is recent enough.
expireEntries :: UTCTime -> NatTable -> IO ()
expireEntries now NatTable { .. } =
  atomicModifyIORef' natTable (\q -> (prune q, ()))
  where
    cutoff = addUTCTime (negate fourMinutes) now

    prune q =
      case Q.minView q of
        Just (key,stamp,sess,rest)
          | stamp < cutoff -> prune (Q.delete (fmap riSource (otherSide key sess)) rest)
        -- empty table, or the oldest entry is still fresh
        _ -> q
-- | Find the session stored under @key@ (either side's flow).  On a hit, the
-- session's timestamp is refreshed to the current time, keeping it alive
-- for 'expireEntries' and 'removeOldest'.
lookupNatTable :: Flow Addr -> NatTable -> IO (Maybe Session)
lookupNatTable key NatTable { .. } =
  do now <- getCurrentTime
     let refresh q =
           maybe (q, Nothing)
                 (\(sess,rest) -> (addSession now sess rest, Just sess))
                 (removeSession key q)
     atomicModifyIORef' natTable refresh
-- | Loop forever.  Every two minutes, expire stale sessions from the TCP
-- table and then from the UDP table.
reaper :: NatTable -> NatTable -> IO ()
reaper tcp udp = forever sweep
  where
    intervalMicros :: Int
    intervalMicros = 2 * 60 * 1000000

    sweep =
      do threadDelay intervalMicros
         now <- getCurrentTime
         mapM_ (expireEntries now) [tcp, udp]
-- | Idle lifetime of a NAT session, in seconds (used by 'expireEntries').
fourMinutes :: NominalDiffTime
fourMinutes = 240
-- | Atomically prepend a TCP port-forwarding rule.  Rules are searched
-- front-to-back, so a newly added rule is tried first.
addTcpPortForward :: HasNatState state => state -> PortForward -> IO ()
addTcpPortForward state rule =
  atomicModifyIORef' rulesRef (\rules -> (rule : rules, ()))
  where
    rulesRef = natTcpRules_ (view natState state)
-- | Atomically remove every TCP rule whose source address and port both
-- equal the given ones.
removeTcpPortForward :: HasNatState state => state -> Addr -> TcpPort -> IO ()
removeTcpPortForward state addr port =
  atomicModifyIORef' rulesRef (\rules -> (filter (not . matches) rules, ()))
  where
    rulesRef   = natTcpRules_ (view natState state)
    matches pf = pfSourceAddr pf == addr && pfSourcePort pf == port
-- | Atomically prepend a UDP port-forwarding rule.  Rules are searched
-- front-to-back, so a newly added rule is tried first.
addUdpPortForward :: HasNatState state => state -> PortForward -> IO ()
addUdpPortForward state rule =
  atomicModifyIORef' rulesRef (\rules -> (rule : rules, ()))
  where
    rulesRef = natUdpRules_ (view natState state)
-- | Atomically remove every UDP rule whose source address and port both
-- equal the given ones.
removeUdpPortForward :: HasNatState state => state -> Addr -> UdpPort -> IO ()
removeUdpPortForward state addr port =
  atomicModifyIORef' rulesRef (\rules -> (filter (not . matches) rules, ()))
  where
    rulesRef   = natUdpRules_ (view natState state)
    matches pf = pfSourceAddr pf == addr && pfSourcePort pf == port
-- | Look up the TCP session for a flow, refreshing its timestamp on a hit.
tcpForwardingActive :: HasNatState state
                    => state -> Flow Addr -> IO (Maybe Session)
tcpForwardingActive state key =
  lookupNatTable key (natTcpTable_ (view natState state))
-- | Look up the UDP session for a flow, refreshing its timestamp on a hit.
udpForwardingActive :: HasNatState state
                    => state -> Flow Addr -> IO (Maybe Session)
udpForwardingActive state key =
  lookupNatTable key (natUdpTable_ (view natState state))
-- | Record a new TCP session (may evict the least recently used one).
addTcpSession :: HasNatState state => state -> Session -> IO ()
addTcpSession state sess =
  insertNatTable sess (natTcpTable_ (view natState state))
-- | Record a new UDP session (may evict the least recently used one).
addUdpSession :: HasNatState state => state -> Session -> IO ()
addUdpSession state sess =
  insertNatTable sess (natUdpTable_ (view natState state))
-- | Does a forwarding rule match a flow?  The flow's local port must equal
-- the rule's source port.  Its local address must equal the rule's source
-- address, unless that address is a wildcard.
ruleApplies :: Flow Addr -> PortForward -> Bool
ruleApplies flow rule = portMatches && addrMatches
  where
    portMatches = flowLocalPort flow == pfSourcePort rule
    addrMatches = flowLocal flow == pfSourceAddr rule
               || isWildcardAddr (pfSourceAddr rule)
-- | Return the first TCP forwarding rule that matches the flow, if any.  The
-- search runs eagerly against a snapshot of the current rule list.
shouldForwardTcp :: HasNatState state
                 => state -> Flow Addr -> IO (Maybe PortForward)
shouldForwardTcp state flow =
  do rules <- readIORef (natTcpRules_ (view natState state))
     let hit = find (ruleApplies flow) rules
     hit `seq` return hit
-- | Return the first UDP forwarding rule that matches the flow, if any.  The
-- search runs eagerly against a snapshot of the current rule list.
shouldForwardUdp :: HasNatState state
                 => state -> Flow Addr -> IO (Maybe PortForward)
shouldForwardUdp state flow =
  do rules <- readIORef (natUdpRules_ (view natState state))
     let hit = find (ruleApplies flow) rules
     hit `seq` return hit