-- | Infer the client's real remote IP address from request headers (such as @X-Forwarded-For@) added by trusted proxies
module Network.Wai.Middleware.RealIp
    ( realIp
    , realIpHeader
    , realIpTrusted
    , defaultTrusted
    , ipInRange
    ) where

import qualified Data.ByteString.Char8 as B8 (split, unpack)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai (Middleware, remoteHost, requestHeaders)
import Text.Read (readMaybe)

-- | Infer the remote IP address from the @X-Forwarded-For@ header,
-- trusting requests that arrive from any loopback or private address
-- (see 'defaultTrusted'). See 'realIpHeader' and 'realIpTrusted' for
-- more information and options.
--
-- @since 3.1.5
realIp :: Middleware
realIp = realIpHeader "X-Forwarded-For"

-- | Infer the remote IP address using the given header, trusting
-- requests that arrive from any address in 'defaultTrusted' (loopback
-- and private ranges). See 'realIpTrusted' for more information and
-- options.
--
-- @since 3.1.5
realIpHeader :: HeaderName -> Middleware
realIpHeader header =
    realIpTrusted header $ \ip -> any (ipInRange ip) defaultTrusted

-- | Infer the remote IP address using the given header, but only if the
-- request came from an IP that is trusted by the provided predicate.
--
-- The last non-trusted address is used to replace the 'remoteHost' in
-- the 'Request', unless all present IP addresses are trusted, in which
-- case the first address is used. Invalid IP addresses are ignored, and
-- the 'remoteHost' value remains unaltered if no valid IP addresses are
-- found. The port of the original 'remoteHost' is kept; only the address
-- is replaced.
--
-- If 'IP.fromSockAddr' cannot decode the original 'remoteHost'
-- (presumably the case for non-IP sockets), the request is passed
-- through unchanged.
--
-- Examples:
--
-- @ realIpTrusted "X-Forwarded-For" $ flip ipInRange "10.0.0.0/8" @
--
-- @ realIpTrusted "X-Real-Ip" $ \\ip -> any (ipInRange ip) defaultTrusted @
--
-- @since 3.1.5
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req respond = app req' respond
  where
    -- Any 'Nothing' along the way means "leave the request alone".
    req' = fromMaybe req $ do
        (peerIp, port) <- IP.fromSockAddr (remoteHost req)
        -- Only believe the header when the direct peer is trusted;
        -- otherwise any client could spoof its address via the header.
        clientIp <- if isTrusted peerIp
            then findRealIp (requestHeaders req) header isTrusted
            else Nothing
        -- Reuse the peer's port: the header carries only addresses.
        pure (req { remoteHost = IP.toSockAddr (clientIp, port) })

-- | Standard loopback and private-use IP ranges, for IPv4 and IPv6:
--
-- * @127.0.0.0/8@ and @::1/128@ (loopback)
--
-- * @10.0.0.0/8@, @172.16.0.0/12@ and @192.168.0.0/16@ (RFC 1918 private)
--
-- * @fc00::/7@ (RFC 4193 unique local)
--
-- The entries are string literals converted through the 'IsString'
-- instance of 'IP.IPRange', so this module needs @OverloadedStrings@.
--
-- @since 3.1.5
defaultTrusted :: [IP.IPRange]
defaultTrusted =
    [ "127.0.0.0/8"
    , "10.0.0.0/8"
    , "172.16.0.0/12"
    , "192.168.0.0/16"
    , "::1/128"
    , "fc00::/7"
    ]

-- | Check if the given IP address is in the given range.
--
-- IPv4 addresses can be checked against IPv6 ranges (they are converted
-- with 'IP.ipv4ToIPv6' first), but testing an IPv6 address against an
-- IPv4 range is always 'False'.
--
-- NOTE(review): consequently, an IPv4-mapped IPv6 address such as
-- @::ffff:127.0.0.1@ never matches an IPv4 range like @127.0.0.0/8@.
-- Keep this in mind for dual-stack listeners.
--
-- @since 3.1.5
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
-- Spelled out instead of a wildcard so the match stays exhaustively checked.
ipInRange (IP.IPv6 _)  (IP.IPv4Range _) = False


-- | Pick the client address out of every occurrence of @header@.
--
-- All values of the header are split on commas, so repeated headers and
-- comma-separated lists are treated alike. Entries that do not parse as
-- an IP address are dropped. The right-most untrusted address is
-- returned, since everything to its right is trusted. By the usual
-- @X-Forwarded-For@ convention, those entries were appended by our own
-- proxies. If every address is trusted, the left-most one is returned.
-- 'Nothing' means no valid address was found.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted =
    case filter (not . isTrusted) ips of
      []         -> listToMaybe ips
      nonTrusted -> listToMaybe (reverse nonTrusted)
  where
    -- Account for repeated headers, not just the first occurrence.
    headerVals = [ v | (k, v) <- reqHeaders, k == header ]

    ips :: [IP.IP]
    ips = mapMaybe (readMaybe . B8.unpack) (concatMap (B8.split ',') headerVals)