{-# LANGUAGE RankNTypes, MultiWayIf, ScopedTypeVariables, LambdaCase #-}

module Test.Sandwich.WebDriver.Internal.Ports (
  findFreePortOrException
  ) where

import Control.Exception
import Control.Retry
import Data.Maybe
import Data.String.Interpolate
import qualified Data.Text as T
import Network.Socket
import System.Random (randomRIO)
import Test.Sandwich.WebDriver.Internal.Util
import Text.Read (readMaybe)

-- | Lowest port number this module will ever propose as a candidate.
firstUserPort :: PortNumber
firstUserPort = 1024

-- | Highest port number this module will ever propose as a candidate.
highestPort :: PortNumber
highestPort = 65535

-- | Find an unused port by repeatedly drawing a candidate and probing it.
--
-- * @policy@ controls how many attempts are made before giving up.
-- * @getAcceptableCandidate@ produces the next port to probe.
--
-- Returns @Just port@ for the first candidate that could be bound on
-- 127.0.0.1, or @Nothing@ once the retry policy is exhausted.
findFreePortInRange' :: RetryPolicy -> IO PortNumber -> IO (Maybe PortNumber)
findFreePortInRange' policy getAcceptableCandidate =
  retrying policy (\_retryStatus result -> return (isNothing result)) (const probeCandidate)
  where
    -- Draw one candidate and report whether it could be bound.
    probeCandidate :: IO (Maybe PortNumber)
    probeCandidate = do
      candidate <- getAcceptableCandidate
      -- Only I/O failures mean "this port is unavailable". Anything else
      -- (in particular asynchronous exceptions such as the one raised by a
      -- surrounding timeout) must propagate instead of triggering a retry.
      outcome <- try (tryOpenAndClosePort candidate)
      return $ case outcome of
        Left (_ :: IOException) -> Nothing
        Right () -> Just candidate

    -- Bind a TCP socket to the port on the loopback address and close it
    -- again. 'bracket' guarantees the socket is closed even when
    -- 'setSocketOption' or 'bind' throws.
    tryOpenAndClosePort :: PortNumber -> IO ()
    tryOpenAndClosePort port =
      bracket (socket AF_INET Stream 0) close $ \sock -> do
        setSocketOption sock ReuseAddr 1
        bind sock (SockAddrInet port (tupleToHostAddress (127, 0, 0, 1)))

-- | 'findFreePortInRange'' with a fixed policy of at most 50 retries.
findFreePortInRange :: IO PortNumber -> IO (Maybe PortNumber)
findFreePortInRange getCandidate = findFreePortInRange' retryPolicy getCandidate
  where
    retryPolicy = limitRetries 50

-- | Find an unused port outside the ephemeral port range.
-- See https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers
-- Candidates are drawn by 'getNonEphemeralCandidate' and probed through
-- 'findFreePortInRange'. No timeout is applied here, on the expectation that
-- a free port exists somewhere in the range; it might be advisable to wrap
-- the call in a timeout anyway.
findFreePort :: IO (Maybe PortNumber)
findFreePort = findFreePortInRange candidateSource
  where
    candidateSource = getNonEphemeralCandidate

-- | Like 'findFreePort', but fails with 'error' instead of returning
-- 'Nothing' when no free port could be found.
findFreePortOrException :: IO PortNumber
findFreePortOrException = do
  found <- findFreePort
  case found of
    Nothing -> error "Couldn't find free port"
    Just port -> return port

-- * Util

-- | Pick a uniformly random port between 'firstUserPort' and 'highestPort'
-- that lies outside the ephemeral port range.
--
-- The ephemeral range comes from 'getEphemeralPortRange'; when that fails,
-- 49152-65535 is assumed. Both bounds of the range are treated as inclusive,
-- so neither boundary port is ever returned (unless the ephemeral range
-- covers every user port, see below).
getNonEphemeralCandidate :: IO PortNumber
getNonEphemeralCandidate = do
  (ephemeralStart, ephemeralEnd) <- either (const (49152, 65535)) id <$> getEphemeralPortRange

  -- All arithmetic is done in Integer: PortNumber arithmetic wraps modulo
  -- 2^16, which would turn an unusual range into a huge bogus count.
  let lowest = toInteger firstUserPort
      highest = toInteger highestPort
      -- Clamp the excluded interval so both counts below are non-negative
      -- and every generated port stays within [lowest, highest].
      ephLo = min (highest + 1) (max lowest (toInteger ephemeralStart))
      ephHi = max (lowest - 1) (min highest (toInteger ephemeralEnd))
      numBelow = ephLo - lowest -- ports in [lowest, ephLo - 1]
      numAbove = highest - ephHi -- ports in [ephHi + 1, highest]
      total = numBelow + numAbove

  if total == 0
    -- The ephemeral range covers every user port; there is nothing better
    -- to offer than an arbitrary port from the whole range.
    then fromInteger <$> randomRIO (lowest, highest)
    else do
      -- Index into the concatenation of the lower and upper segments so
      -- every eligible port is equally likely.
      k <- randomRIO (0, total - 1)
      return . fromInteger $
        if k < numBelow
          then lowest + k
          else ephHi + 1 + (k - numBelow)

-- | Read the ephemeral port range from
-- @/proc/sys/net/ipv4/ip_local_port_range@.
--
-- The file is expected to contain exactly two whitespace-separated port
-- numbers. Any failure (missing file, wrong number of fields, a field that
-- is not a number in 0-65535) is raised inside 'leftOnException'' and so is
-- reported through its 'Left' result.
getEphemeralPortRange :: IO (Either T.Text (PortNumber, PortNumber))
getEphemeralPortRange = leftOnException' $ do
  contents <- readFile "/proc/sys/net/ipv4/ip_local_port_range"
  -- Matching on 'Just' forces both fields to be parsed here, so a malformed
  -- file is detected inside the handler rather than when the tuple is used.
  case traverse parsePort (words contents) of
    Just [p1, p2] -> return (p1, p2)
    _ -> throwIO $ userError [i|Unexpected contents: '#{contents}'|]
  where
    -- Parse via Integer and range-check, so out-of-range values are rejected
    -- instead of wrapping around.
    parsePort :: String -> Maybe PortNumber
    parsePort str = case readMaybe str :: Maybe Integer of
      Just n | 0 <= n && n <= 65535 -> Just (fromInteger n)
      _ -> Nothing