{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}

-- | This module is inspired by the http-reverse-proxy package:
-- https://hackage.haskell.org/package/http-reverse-proxy

module Test.Sandwich.Contexts.ReverseProxy.TCP where

#ifndef mingw32_HOST_OS

import Control.Monad.IO.Unlift
import Data.Conduit
import qualified Data.Conduit.Network as DCN
import qualified Data.Conduit.Network.Unix as DCNU
import Data.Streaming.Network (setAfterBind)
import Data.String.Interpolate
import Network.Socket
import Relude
import Test.Sandwich (expectationFailure)
import UnliftIO.Async
import UnliftIO.Exception


-- | Run a TCP server that proxies every incoming connection to the Unix domain
-- socket at the given path, and hand the TCP port it bound to the callback.
--
-- The server is started in a background thread via 'withAsync', so it only
-- lives for the duration of the callback (per 'withAsync''s contract the
-- thread is cancelled once the callback returns or throws).
--
-- The server binds to port @0@ with host preference @"*"@; the port that was
-- actually assigned is read back from the bound socket and passed to the
-- callback.
--
-- NOTE(review): if the server thread throws before a port has been published,
-- the callback is invoked with port @0@ rather than seeing the exception.
-- Callers may want to treat @0@ as "proxy failed to start" -- confirm against
-- call sites.
withProxyToUnixSocket :: MonadUnliftIO m => FilePath -> (PortNumber -> m a) -> m a
withProxyToUnixSocket socketPath f = do
  -- Filled exactly once with the bound port (or 0 if the server fails first).
  portVar <- newEmptyMVar
  let ss = DCN.serverSettings 0 "*"
         & setAfterBind (\sock -> do
             -- Publish the port the listening socket actually ended up on.
             getSocketName sock >>= \case
               SockAddrInet port _ -> putMVar portVar port
               SockAddrInet6 port _ _ _ -> putMVar portVar port
               x -> expectationFailure [i|withProxyToUnixSocket: expected to bind a TCP socket, but got other addr: #{x}|]
           )
  -- The 'onException' handler unblocks the 'readMVar' below when the server
  -- dies before the after-bind hook has published a port.
  withAsync (liftIO $ DCN.runTCPServer ss app `onException` (putMVar portVar 0)) $ \_ ->
    readMVar portVar >>= f

  where
    -- Per-connection handler: open a client connection to the Unix socket and
    -- pump bytes in both directions concurrently until either side finishes.
    app :: DCN.AppData -> IO ()
    app appdata = DCNU.runUnixClient (DCNU.clientSettings socketPath) $ \appdataServer ->
      concurrently_
        -- TCP client -> Unix socket
        (runConduit $ DCN.appSource appdata .| DCN.appSink appdataServer)
        -- Unix socket -> TCP client
        (runConduit $ DCN.appSource appdataServer .| DCN.appSink appdata)

#endif