{-# LANGUAGE RankNTypes, MultiWayIf, ScopedTypeVariables, LambdaCase #-}

module Test.Sandwich.WebDriver.Internal.Ports (
  findFreePortOrException
  ) where

import Control.Exception
import Data.Maybe
import Text.Read (readMaybe)

import Control.Retry
import Data.String.Interpolate
import qualified Data.Text as T
import Network.Socket
import System.Random (randomRIO)

import Test.Sandwich.WebDriver.Internal.Util

-- | Lowest port that 'getNonEphemeralCandidate' will ever propose
-- (1024; ports below it are conventionally the reserved/privileged range).
firstUserPort :: PortNumber
firstUserPort = 0x400 -- 1024

-- | Highest port number representable in a 16-bit TCP port field (65535);
-- the upper bound for candidate ports.
highestPort :: PortNumber
highestPort = 0xFFFF

-- | Repeatedly draw a candidate port from @getAcceptableCandidate@ and check
-- whether a TCP socket can be bound to it on 127.0.0.1, retrying according to
-- @policy@.
--
-- Returns @'Just' port@ for the first candidate that binds successfully, or
-- 'Nothing' once the retry policy gives up with every attempt having failed.
-- Exceptions thrown by @getAcceptableCandidate@ itself are not caught.
findFreePortInRange' :: RetryPolicy -> IO PortNumber -> IO (Maybe PortNumber)
findFreePortInRange' policy getAcceptableCandidate =
  retrying policy (\_retryStatus result -> return $ isNothing result) (const findFreePortInRange'')
  where
    findFreePortInRange'' :: IO (Maybe PortNumber)
    findFreePortInRange'' = do
      candidate <- getAcceptableCandidate
      -- Only socket/IO failures mean "this port is not usable". Catching
      -- SomeException here would also swallow asynchronous exceptions (e.g.
      -- from 'System.Timeout.timeout' or 'killThread'), making the search
      -- impossible to cancel.
      (tryOpenAndClosePort candidate >> return (Just candidate))
        `catch` (\(_ :: IOException) -> return Nothing)

    -- Bind a fresh TCP socket to 127.0.0.1:port, then close it again.
    -- 'bracket' guarantees the socket is closed even when 'bind' throws
    -- (the expected outcome for a busy port), so failed probes don't leak
    -- file descriptors.
    tryOpenAndClosePort :: PortNumber -> IO ()
    tryOpenAndClosePort port =
      bracket (socket AF_INET Stream 0) close $ \sock -> do
        setSocketOption sock ReuseAddr 1
        bind sock (SockAddrInet port (tupleToHostAddress (127, 0, 0, 1)))

-- | 'findFreePortInRange'' with a fixed policy of at most 50 immediate
-- retries ('limitRetries').
findFreePortInRange :: IO PortNumber -> IO (Maybe PortNumber)
findFreePortInRange getCandidate =
  findFreePortInRange' (limitRetries retryLimit) getCandidate
  where
    retryLimit :: Int
    retryLimit = 50

-- | Find an unused TCP port on 127.0.0.1, drawing candidates from outside the
-- kernel's ephemeral port range (see 'getNonEphemeralCandidate'), so that the
-- chosen port is less likely to collide with ports the OS hands out to
-- outgoing connections.
-- See https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers
--
-- The search is bounded by the retry policy of 'findFreePortInRange', so it
-- returns 'Nothing' rather than looping forever if no candidate can be bound.
findFreePort :: IO (Maybe PortNumber)
findFreePort = findFreePortInRange getNonEphemeralCandidate

-- | Like 'findFreePort', but fails with 'error' (an 'ErrorCall' exception)
-- instead of returning 'Nothing' when no free port could be found.
findFreePortOrException :: IO PortNumber
findFreePortOrException =
  findFreePort >>= maybe (error "Couldn't find free port") return

-- * Util

-- | Pick a random port in @[firstUserPort, highestPort]@ that lies outside
-- the kernel's ephemeral port range.
--
-- Removing the ephemeral range from the user-port space leaves up to two
-- intervals, one below it and one above it. A single uniform draw over their
-- combined size selects the port, so every non-ephemeral user port is equally
-- likely.
--
-- If the ephemeral range cannot be determined, 49152–65535 is assumed. If the
-- ephemeral range covers every user port, any user port is returned and the
-- caller's bind check decides whether it is usable.
getNonEphemeralCandidate :: IO PortNumber
getNonEphemeralCandidate = do
  (ephemeralStart, ephemeralEnd) <- getEphemeralPortRange >>= \case
    Left _ -> return (49152, 65535)
    Right range -> return range

  -- Work in Integer: PortNumber is a 16-bit word, so expressions such as
  -- @ephemeralEnd + 1@ or @ephemeralStart - firstUserPort@ could silently wrap.
  let lo, hi :: Integer
      lo = toInteger firstUserPort
      hi = toInteger highestPort

      -- Inclusive bounds that exclude the ephemeral endpoints themselves.
      belowRange, aboveRange :: (Integer, Integer)
      belowRange = (lo, min hi (toInteger ephemeralStart - 1))
      aboveRange = (max lo (toInteger ephemeralEnd + 1), hi)

      size :: (Integer, Integer) -> Integer
      size (a, b) = max 0 (b - a + 1)

      numBelow = size belowRange
      numAbove = size aboveRange
      total = numBelow + numAbove

  if | total <= 0 ->
         fromInteger <$> randomRIO (lo, hi)
     | otherwise -> do
         n <- randomRIO (0, total - 1)
         return $ fromInteger $
           if n < numBelow
             then fst belowRange + n
             else fst aboveRange + (n - numBelow)

-- | Read the kernel's local (ephemeral) port range from
-- @\/proc\/sys\/net\/ipv4\/ip_local_port_range@, which is expected to hold
-- exactly two port numbers separated by whitespace.
--
-- Failures (unreadable file, or contents that are not exactly two valid port
-- numbers) raise an exception inside 'leftOnException'', which presumably
-- converts it into a 'Left' carrying a message. NOTE(review): confirm against
-- that helper's definition.
getEphemeralPortRange :: IO (Either T.Text (PortNumber, PortNumber))
getEphemeralPortRange = leftOnException' $ do
  contents <- readFile "/proc/sys/net/ipv4/ip_local_port_range"
  -- Parse eagerly with 'readMaybe'. A lazy 'read' would leave the ports as
  -- thunks whose parse error surfaces later in the caller, i.e. outside
  -- 'leftOnException''.
  case traverse readMaybe (words contents) of
    Just [p1, p2] -> return (p1, p2)
    _ -> error [i|Unexpected contents: '#{contents}'|]