{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE OverloadedStrings   #-}


module Tunnel
    ( runClient
    , runServer
    , rrunTCPClient
    ) where

import           ClassyPrelude
import           Data.Maybe                    (fromJust)

import qualified Data.ByteString.Char8         as BC

import qualified Data.Conduit.Network.TLS      as N
import qualified Data.Streaming.Network        as N

import           Network.Socket                (HostName, PortNumber)
import qualified Network.Socket                as N hiding (recv, recvFrom,
                                                     send, sendTo)
import qualified Network.Socket.ByteString     as N

import qualified Network.WebSockets            as WS
import qualified Network.WebSockets.Connection as WS
import qualified Network.WebSockets.Stream     as WS

import           Control.Monad.Except
import qualified Network.Connection            as NC
import           System.IO                     (IOMode (ReadWriteMode))

import qualified Data.ByteString.Base64        as B64

import           Types
import           Protocols
import qualified Socks5
import           Logger
import qualified Credentials



-- | Open a TCP connection described by @cfg@, tune its socket options, and run
-- @app@ with a 'Connection' wrapping the socket. The socket is closed when
-- @app@ returns or throws.
--
-- Socket tuning (receive/send buffer sizes and the optional SO_MARK) happens
-- after the socket is connected. 'bracket' only releases resources whose
-- acquisition completed, so the socket is closed explicitly if the tuning
-- step throws (for example when setting SO_MARK is not permitted).
rrunTCPClient :: N.ClientSettings -> (Connection -> IO a) -> IO a
rrunTCPClient cfg app = bracket openSocket closeSocket runApp
  where
    openSocket = do
      (s, addr) <- N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg)
      -- Do not leak the connected socket if configuring it fails.
      configure s `onException` N.close s
      return (s, addr)

    configure s = do
      N.setSocketOption s N.RecvBuffer defaultRecvBufferSize
      N.setSocketOption s N.SendBuffer defaultSendBufferSize
      soMarkVal <- readIORef sO_MARK_Value
      -- A mark value of 0 skips SO_MARK; also skip it where the option is unsupported.
      when (soMarkVal /= 0 && N.isSupportedSocketOption sO_MARK) $
        N.setSocketOption s sO_MARK soMarkVal

    -- Closing is best-effort: any exception raised by close is ignored.
    closeSocket (s, _) = catch (N.close s) (\(_ :: SomeException) -> return ())

    runApp (s, _) = app Connection
      { read = Just <$> N.safeRecv s defaultRecvBufferSize
      , write = N.sendAll s
      , close = N.close s
      , rawConnection = Just s
      }

--
--  Pipes
--
-- | Perform the WebSocket upgrade over an already-open 'Connection' and run
-- @app@ on the resulting WebSocket, wrapped back into a 'Connection'.
--
-- The request path is @toPath cfg@. The hostname given to
-- 'WS.runClientWithStream' is 'hostHeader' when non-empty, otherwise
-- 'serverHost'. An @Authorization: Basic …@ header is added when
-- 'upgradeCredentials' is non-empty. A ping thread is started with
-- 'websocketPingFrequencySec'. Any exception is reported as a
-- 'WebsocketError' inside the returned @m ()@.
tunnelingClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do
  debug "Oppening Websocket stream"

  stream <- connectionToStream conn
  let headers = if not (null upgradeCredentials) then [("Authorization", "Basic " <> B64.encode upgradeCredentials)] else []
  let hostname = if not (null hostHeader) then BC.unpack hostHeader else serverHost

  ret <- WS.runClientWithStream stream hostname (toPath cfg) WS.defaultConnectionOptions headers run

  debug "Closing Websocket stream"
  return ret

  where
    -- The websockets 'Stream' writer is called with 'Nothing' to signal the end
    -- of the stream; treat that as a no-op instead of crashing in 'fromJust'.
    connectionToStream Connection{..} = WS.makeStream read (maybe (return ()) (write . toStrict))

    onError = flip catch (\(e :: SomeException) -> return . throwError . WebsocketError $ show e)

    run cnx = do
      WS.forkPingThread cnx websocketPingFrequencySec
      app (toConnection cnx)


-- | Upgrade an already-open TCP 'Connection' to TLS and run @app@ on the TLS
-- connection.
--
-- The SNI/hostname is 'tlsSNI' when non-empty, otherwise 'serverHost'.
-- Certificate validation is disabled ('NC.settingDisableCertificateValidation'
-- is 'True'), so the server certificate is not verified.
--
-- Errors are reported as a 'TlsError' inside the returned @m ()@. That
-- includes the case where the connection carries no raw socket. The handle
-- created from the socket is closed whether the handshake or @app@ succeeds
-- or throws.
tlsClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tlsClientP TunnelSettings{..} app conn = onError $ do
    debug "Doing tls Handshake"
    case rawConnection conn of
      Nothing     -> return . throwError $ TlsError "Cannot do tls handshake: connection has no underlying socket"
      Just socket -> handshakeAndRun socket

  where
    handshakeAndRun socket = do
      context <- NC.initConnectionContext
      h <- N.socketToHandle socket ReadWriteMode
      -- socketToHandle hands ownership of the descriptor to the handle, so the
      -- handshake must also be covered by 'finally' or a failed handshake
      -- leaks the handle.
      ret <- flip finally (hClose h) $ do
        connection <- NC.connectFromHandle context h connectionParams
        app (toConnection connection)
      debug "Closing TLS"
      return ret

    onError = flip catch (\(e :: SomeException) -> return . throwError . TlsError $ show e)

    tlsSettings = NC.TLSSettingsSimple { NC.settingDisableCertificateValidation = True
                                       , NC.settingDisableSession = False
                                       , NC.settingUseServerName = False
                                       }

    connectionParams = NC.ConnectionParams { NC.connectionHostname = if tlsSNI == mempty then serverHost else BC.unpack tlsSNI
                                           , NC.connectionPort = serverPort
                                           , NC.connectionUseSecure = Just tlsSettings
                                           , NC.connectionUseSocks = Nothing
                                           }


--
--  Connectors
--
-- | Open a plain TCP connection to 'serverHost':'serverPort' and run @app@ on it.
--
-- An exception whose 'show' rendering starts with "user error" is reported as
-- a 'TunnelError' inside the returned @m ()@. Every other exception is
-- swallowed and yields a successful @m ()@.
tcpConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
tcpConnection TunnelSettings{..} app = flip catch toTunnelError $ do
  debug $ "Oppening tcp connection to " <> target
  ret <- rrunTCPClient tcpSettings app
  debug $ "Closing tcp connection to " <> target
  return ret

  where
    -- "host:port" used in the log lines.
    target = fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int)

    tcpSettings = N.clientSettingsTCP (fromIntegral serverPort) (fromString serverHost)

    toTunnelError (e :: SomeException) =
      let rendered = show e
          looksLikeUserError = take 10 rendered == "user error"
      in return $ when looksLikeUserError (throwError $ TunnelError rendered)



-- | Reach the tunnel server through the HTTP proxy in 'proxySetting'. The steps:
--
--   * open a TCP connection to the proxy;
--   * send @CONNECT serverHost:serverPort HTTP/1.0@, adding a
--     @Proxy-Authorization: Basic …@ header when the proxy has credentials;
--   * wait up to 10 seconds for the response headers;
--   * run @app@ on the connection if the response contains @" 200 "@.
--
-- Errors are reported inside the returned @m ()@:
--
--   * 'ProxyForwardError' with the proxy's answer, or a timeout / early-close
--     message, when the CONNECT is not accepted;
--   * 'ProxyConnectionError' when no proxy is configured, or for exceptions
--     whose 'show' rendering starts with "user error".
--
-- Any other exception is swallowed and yields a successful @m ()@.
httpProxyConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
httpProxyConnection TunnelSettings{..} app = onError $
  case proxySetting of
    Nothing       -> return . throwError $ ProxyConnectionError "No proxy settings provided for the http proxy connection"
    Just settings -> viaProxy settings

  where
    viaProxy settings = do
      -- NOTE(review): this logs 'show settings'; confirm the ProxySettings Show
      -- instance does not print the proxy password.
      debug $ "Oppening tcp connection to proxy " <> show settings

      ret <- rrunTCPClient (N.clientSettingsTCP (fromIntegral (port settings)) (BC.pack $ host settings)) $ \conn -> do
        sendConnectRequest settings conn
        responseM <- timeout (1000000 * 10) $ readConnectResponse mempty conn
        let response = case responseM of
              Nothing -> "No response of the proxy after 10s"
              Just r
                | BC.null r -> "Proxy closed the connection without answering"
                | otherwise -> r

        if isAuthorized response
        then app conn
        else return . throwError . ProxyForwardError $ BC.unpack response

      debug $ "Closing tcp connection to proxy " <> show settings
      return ret

    credentialsToHeader (user, password) = "Proxy-Authorization: Basic " <> B64.encode (user <> ":" <> password) <> "\r\n"

    sendConnectRequest settings h = write h $
         "CONNECT " <> fromString serverHost <> ":" <> fromString (show serverPort) <> " HTTP/1.0\r\n"
      <> "Host: " <> fromString serverHost <> ":" <> fromString (show serverPort) <> "\r\n"
      <> maybe mempty credentialsToHeader (credentials settings)
      <> "\r\n"

    -- Accumulate chunks until the end of the response headers ("\r\n\r\n").
    -- The terminator is searched in the whole buffer because it can be split
    -- across two reads. An empty chunk (end of stream) or 'Nothing' stops the
    -- loop and returns what was received so far, instead of spinning until
    -- the timeout.
    readConnectResponse buff conn = do
      chunk <- read conn
      case chunk of
        Just bytes | not (BC.null bytes) -> do
          let buff' = buff <> bytes
          if "\r\n\r\n" `BC.isInfixOf` buff'
            then return buff'
            else readConnectResponse buff' conn
        _ -> return buff

    isAuthorized response = " 200 " `BC.isInfixOf` response

    onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ ProxyConnectionError $ show e))

--
--  Client
--
runClient :: TunnelSettings -> IO ()
runClient :: TunnelSettings -> IO ()
runClient cfg :: TunnelSettings
cfg@TunnelSettings{Bool
Int
String
Maybe ProxySettings
ByteString
PortNumber
Protocol
websocketPingFrequencySec :: Int
udpTimeout :: Int
hostHeader :: ByteString
tlsSNI :: ByteString
upgradeCredentials :: ByteString
upgradePrefix :: String
useSocks :: Bool
useTls :: Bool
protocol :: Protocol
destPort :: PortNumber
destHost :: String
serverPort :: PortNumber
serverHost :: String
localPort :: PortNumber
localBind :: String
proxySetting :: Maybe ProxySettings
websocketPingFrequencySec :: TunnelSettings -> Int
udpTimeout :: TunnelSettings -> Int
hostHeader :: TunnelSettings -> ByteString
tlsSNI :: TunnelSettings -> ByteString
upgradeCredentials :: TunnelSettings -> ByteString
upgradePrefix :: TunnelSettings -> String
useSocks :: TunnelSettings -> Bool
useTls :: TunnelSettings -> Bool
protocol :: TunnelSettings -> Protocol
destPort :: TunnelSettings -> PortNumber
destHost :: TunnelSettings -> String
serverPort :: TunnelSettings -> PortNumber
serverHost :: TunnelSettings -> String
localPort :: TunnelSettings -> PortNumber
localBind :: TunnelSettings -> String
proxySetting :: TunnelSettings -> Maybe ProxySettings
..} = do
  let withEndPoint :: (Connection -> IO (Either Error ())) -> IO (Either Error ())
withEndPoint = if Maybe ProxySettings -> Bool
forall a. Maybe a -> Bool
isJust Maybe ProxySettings
proxySetting then TunnelSettings
-> (Connection -> IO (Either Error ())) -> IO (Either Error ())
forall (m :: * -> *).
MonadError Error m =>
TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
httpProxyConnection TunnelSettings
cfg else TunnelSettings
-> (Connection -> IO (Either Error ())) -> IO (Either Error ())
forall (m :: * -> *).
MonadError Error m =>
TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
tcpConnection TunnelSettings
cfg
  let doTlsIf :: Bool -> (Connection -> IO (m ())) -> Connection -> IO (m ())
doTlsIf Bool
tlsNeeded Connection -> IO (m ())
app = if Bool
tlsNeeded then TunnelSettings
-> (Connection -> IO (m ())) -> Connection -> IO (m ())
forall (m :: * -> *).
MonadError Error m =>
TunnelSettings
-> (Connection -> IO (m ())) -> Connection -> IO (m ())
tlsClientP TunnelSettings
cfg Connection -> IO (m ())
app else Connection -> IO (m ())
app
  let withTunnel :: TunnelSettings
-> (Connection -> IO (Either Error ())) -> IO (Either Error ())
withTunnel TunnelSettings
cfg' Connection -> IO (Either Error ())
app = (Connection -> IO (Either Error ())) -> IO (Either Error ())
withEndPoint (Bool
-> (Connection -> IO (Either Error ()))
-> Connection
-> IO (Either Error ())
forall (m :: * -> *).
MonadError Error m =>
Bool -> (Connection -> IO (m ())) -> Connection -> IO (m ())
doTlsIf Bool
useTls ((Connection -> IO (Either Error ()))
 -> Connection -> IO (Either Error ()))
-> ((Connection -> IO (Either Error ()))
    -> Connection -> IO (Either Error ()))
-> (Connection -> IO (Either Error ()))
-> Connection
-> IO (Either Error ())
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. TunnelSettings
-> (Connection -> IO (Either Error ()))
-> Connection
-> IO (Either Error ())
forall (m :: * -> *).
MonadError Error m =>
TunnelSettings
-> (Connection -> IO (m ())) -> Connection -> IO (m ())
tunnelingClientP TunnelSettings
cfg' ((Connection -> IO (Either Error ()))
 -> Connection -> IO (Either Error ()))
-> (Connection -> IO (Either Error ()))
-> Connection
-> IO (Either Error ())
forall a b. (a -> b) -> a -> b
$ Connection -> IO (Either Error ())
app)

  let app :: TunnelSettings -> a -> IO ()
app TunnelSettings
cfg' a
localH = do
        Either Error ()
ret <- TunnelSettings
-> (Connection -> IO (Either Error ())) -> IO (Either Error ())
withTunnel TunnelSettings
cfg' ((Connection -> IO (Either Error ())) -> IO (Either Error ()))
-> (Connection -> IO (Either Error ())) -> IO (Either Error ())
forall a b. (a -> b) -> a -> b
$ \Connection
remoteH -> do
          Either Error ()
ret <- Connection
remoteH Connection -> Connection -> IO (Either Error ())
<==> a -> Connection
forall a. ToConnection a => a -> Connection
toConnection a
localH
          String -> IO ()
info (String -> IO ()) -> String -> IO ()
forall a b. (a -> b) -> a -> b
$ String
"CLOSE tunnel :: " String -> String -> String
forall a. Semigroup a => a -> a -> a
<> TunnelSettings -> String
forall a. Show a => a -> String
show TunnelSettings
cfg'
          Either Error () -> IO (Either Error ())
forall (m :: * -> *) a. Monad m => a -> m a
return Either Error ()
ret

        Either Error () -> IO ()
handleError Either Error ()
ret

  case Protocol
protocol of
        Protocol
UDP -> (String, PortNumber) -> Int -> (UdpAppData -> IO ()) -> IO ()
runUDPServer (String
localBind, PortNumber
localPort) Int
udpTimeout (TunnelSettings -> UdpAppData -> IO ()
forall a. ToConnection a => TunnelSettings -> a -> IO ()
app TunnelSettings
cfg)
        Protocol
TCP -> (String, PortNumber) -> (AppData -> IO ()) -> IO ()
runTCPServer (String
localBind, PortNumber
localPort) (TunnelSettings -> AppData -> IO ()
forall a. ToConnection a => TunnelSettings -> a -> IO ()
app TunnelSettings
cfg)
        Protocol
STDIO -> (StdioAppData -> IO ()) -> IO ()
runSTDIOServer (TunnelSettings -> StdioAppData -> IO ()
forall a. ToConnection a => TunnelSettings -> a -> IO ()
app TunnelSettings
cfg)
        Protocol
SOCKS5 -> ServerSettings
-> TunnelSettings -> (TunnelSettings -> AppData -> IO ()) -> IO ()
runSocks5Server (PortNumber -> String -> ServerSettings
Socks5.ServerSettings PortNumber
localPort String
localBind) TunnelSettings
cfg TunnelSettings -> AppData -> IO ()
forall a. ToConnection a => TunnelSettings -> a -> IO ()
app




--
--  Server
--
-- | Listen for TLS-wrapped websocket clients on @endPoint@ and serve each
-- accepted connection with 'serverEventLoop'.
--
-- The certificate and key come from "Credentials". @isAllowed@ is passed to
-- 'serverEventLoop', which uses it to accept or refuse each requested
-- (host, port) target.
runTlsTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTlsTunnelingServer endPoint@(bindTo, portNumber) isAllowed = do
  info $ "WAIT for TLS connection on " <> toStr endPoint

  -- The accept loop is not expected to return normally, so the shutdown
  -- message is attached with 'finally'; otherwise it could never be logged.
  N.runTCPServerTLS tlsConfig serveClient `finally` info "SHUTDOWN server"

  where
    tlsConfig = N.tlsConfigBS (fromString bindTo) (fromIntegral portNumber)
                              Credentials.certificate Credentials.key

    -- One websocket session per accepted TLS client.
    serveClient :: N.AppData -> IO ()
    serveClient sClient =
      runApp sClient WS.defaultConnectionOptions
             (serverEventLoop (N.appSockAddr sClient) isAllowed)

    -- Adapt the raw TLS connection to a websockets 'WS.Stream' and run the
    -- server app on it, closing the stream afterwards.
    runApp :: N.AppData -> WS.ConnectionOptions -> WS.ServerApp -> IO ()
    runApp appData opts app = do
      stream <- WS.makeStream receive send
      bracket (WS.makePendingConnectionFromStream stream opts)
              -- Closing is best-effort cleanup: any failure here is ignored.
              (\conn -> catch (WS.close $ WS.pendingStream conn)
                              (\(_ :: SomeException) -> return ()))
              app
      where
        -- An empty read is mapped to 'Nothing' (end of stream) for websockets.
        receive = N.appRead appData <&> \payload ->
          if payload == mempty then Nothing else Just payload

        -- websockets writes 'Nothing' to close the stream; there is nothing
        -- to send in that case, so it is a no-op rather than a 'fromJust' crash.
        send Nothing      = return ()
        send (Just chunk) = N.appWrite appData (toStrict chunk)

-- | Listen for plain (non-TLS) websocket clients on @endPoint@ and serve each
-- accepted connection with 'serverEventLoop'.
--
-- The listening socket uses 'defaultRecvBufferSize' as its read buffer size.
-- @isAllowed@ is passed to 'serverEventLoop' to accept or refuse each
-- requested (host, port) target.
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTunnelingServer endPoint@(host, port) isAllowed = do
  info $ "WAIT for connection on " <> toStr endPoint

  let srvSet = N.setReadBufferSize defaultRecvBufferSize
             $ N.serverSettingsTCP (fromIntegral port) (fromString host)
  -- The accept loop is not expected to return normally, so the close message
  -- is attached with 'finally'; otherwise it could never be logged.
  void (N.runTCPServer srvSet serveClient) `finally` info "CLOSE server"

  where
    -- Adapt one accepted TCP client to a websockets stream and serve it.
    serveClient :: N.AppData -> IO ()
    serveClient sClient = do
      stream <- WS.makeStream (receive sClient) (send sClient)
      runApp stream WS.defaultConnectionOptions
             (serverEventLoop (N.appSockAddr sClient) isAllowed)

    -- An empty read is mapped to 'Nothing' (end of stream) for websockets.
    receive sClient = N.appRead sClient <&> \payload ->
      if payload == mempty then Nothing else Just payload

    -- websockets writes 'Nothing' to close the stream; there is nothing to
    -- send in that case, so it is a no-op rather than a 'fromJust' crash.
    send _       Nothing      = return ()
    send sClient (Just chunk) = N.appWrite sClient (toStrict chunk)

    -- Run the server app on a pending connection built from the stream,
    -- closing the stream afterwards.
    runApp :: WS.Stream -> WS.ConnectionOptions -> WS.ServerApp -> IO ()
    runApp stream opts =
      bracket (WS.makePendingConnectionFromStream stream opts)
              -- Closing is best-effort cleanup: any failure here is ignored.
              (\conn -> catch (WS.close $ WS.pendingStream conn)
                              (\(_ :: SomeException) -> return ()))

-- | Handle one incoming websocket upgrade request.
--
-- Logs the peer address together with any @X-Forwarded-For@ headers. The
-- request path (see 'fromPath') names the protocol and target. The request is
-- rejected when:
--
-- * the path does not parse,
-- * the protocol is neither TCP nor UDP (only those two can be served here), or
-- * @isAllowed@ refuses the (host, port) target.
--
-- Otherwise the upgrade is accepted, and the websocket is bridged with a new
-- TCP or UDP client connection to the target.
serverEventLoop :: N.SockAddr -> ((ByteString, Int) -> Bool) -> WS.PendingConnection -> IO ()
serverEventLoop sClient isAllowed pendingConn = do
  let request      = WS.pendingRequest pendingConn
      path         = fromPath (WS.requestPath request)
      forwardedFor = filter ((== "x-forwarded-for") . fst) (WS.requestHeaders request)
  info $ "NEW incoming connection from " <> show sClient <> " " <> show forwardedFor
  case path of
    Nothing -> rejectInvalid
    Just (proto, rhost, rport)
      -- Checked before accepting: once the upgrade is accepted the request can
      -- no longer be rejected, and the dispatch below only knows TCP and UDP.
      | proto /= TCP && proto /= UDP -> rejectInvalid
      | not (isAllowed (rhost, rport)) -> do
          info "Rejecting tunneling"
          WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
      | otherwise -> do
          conn <- WS.acceptRequest pendingConn
          let target :: (String, PortNumber)
              target = (BC.unpack rhost, fromIntegral rport)
          case proto of
            UDP -> runUDPClient target (\cnx -> void $ toConnection conn <==> toConnection cnx)
            -- Only TCP can reach this branch (see the guard above).
            _   -> runTCPClient target (\cnx -> void $ toConnection conn <==> toConnection cnx)
  where
    rejectInvalid =
      info "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"


-- | Start the tunneling server. When @useTLS@ is set, the TLS listener is
-- used; otherwise the plain TCP one. The endpoint and target filter are passed
-- through unchanged.
runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runServer useTLS endPoint isAllowed
  | useTLS    = runTlsTunnelingServer endPoint isAllowed
  | otherwise = runTunnelingServer endPoint isAllowed




--
--  Commons
--
-- | Build the websocket request path announcing the tunnel target:
-- @\/\<upgradePrefix\>\/\<udp|tcp\>\/\<destHost\>\/\<destPort\>@.
toPath :: TunnelSettings -> String
toPath cfg =
  "/" <> upgradePrefix cfg <> "/" <> protoName <> "/" <> destHost cfg <> "/" <> show (destPort cfg)
  where
    -- Only UDP is announced as such; every other local protocol is announced
    -- as TCP.
    protoName = toLower . show $ if protocol cfg == UDP then UDP else TCP

-- | Parse a websocket request path of the form
-- @\/\<prefix\>\/\<protocol\>\/\<host\>\/\<port\>@, as produced by 'toPath'.
--
-- Returns 'Nothing' unless all of the following hold:
--
-- * after dropping the first character, the path splits on @/@ into exactly
--   four components;
-- * the protocol names TCP or UDP (case-insensitive), the only protocols
--   'toPath' emits;
-- * the port is a decimal number in 1..65535.
fromPath :: ByteString -> Maybe (Protocol, ByteString, Int)
fromPath path =
  case BC.split '/' (BC.drop 1 path) of
    [_prefix, protoStr, host, portStr] -> do
      port <- readMay (BC.unpack portStr) :: Maybe Int
      -- Out-of-range ports would silently wrap when later converted to a
      -- 16-bit PortNumber, after isAllowed has checked the unwrapped value.
      guard (port > 0 && port <= 65535)
      proto <- readMay (toUpper (BC.unpack protoStr)) :: Maybe Protocol
      guard (proto == TCP || proto == UDP)
      return (proto, host, port)
    _ -> Nothing

-- | Report the outcome of a tunnel.
--
-- Success is silent. A failure logs a user-facing headline at error level
-- (except for 'Other', which has none), followed by the detail message at
-- debug level between @====@ rulers.
handleError :: Either Error () -> IO ()
handleError = either report (\() -> return ())
  where
    report :: Error -> IO ()
    report failure =
      let (headline, details) = describe failure
      in maybe (return ()) err headline >> debugPP details

    -- Map each error to its optional headline and its detail text.
    describe :: Error -> (Maybe String, String)
    describe (ProxyConnectionError msg) = (Just "Cannot connect to the proxy", msg)
    describe (ProxyForwardError msg)    = (Just "Connection not allowed by the proxy", msg)
    describe (TunnelError msg)          = (Just "Cannot establish the connection to the server", msg)
    describe (LocalServerError msg)     = (Just "Cannot create the localServer, port already binded ?", msg)
    describe (WebsocketError msg)       = (Just "Cannot establish websocket connection with the server", msg)
    describe (TlsError msg)             = (Just "Cannot do tls handshake with the server", msg)
    describe (Other msg)                = (Nothing, msg)

    debugPP :: String -> IO ()
    debugPP msg = debug $ "====\n" <> msg <> "\n===="

-- | Run an IO action and discard its result. Any exception caught by 'try'
-- becomes 'Other' (carrying the exception's 'show' text) in the caller's error
-- monad; success becomes @return ()@.
myTry :: MonadError Error m => IO a -> IO (m ())
myTry action = do
  outcome <- try action
  return $ case outcome of
    Left (e :: SomeException) -> throwError (Other (show e))
    Right _                   -> return ()

-- | Bridge two connections.
--
-- Runs 'propagateReads' (tunnel to other) and 'propagateWrites' (other to
-- tunnel) concurrently with 'race_'. The bridge therefore ends as soon as
-- either direction finishes or throws. A thrown exception is reported through
-- 'myTry' as a 'Left'.
(<==>) :: Connection -> Connection -> IO (Either Error ())
hTunnel <==> hOther =
  myTry (propagateReads hTunnel hOther `race_` propagateWrites hTunnel hOther)

-- | Copy every chunk read from @hTunnel@ to @hOther@.
--
-- Stops and returns normally once 'read' on @hTunnel@ yields 'Nothing'
-- (presumably the end of the stream; confirm against the Connection
-- implementations). The loop previously used 'fromJust', which turned that
-- case into an exception reported as an error.
propagateReads :: Connection -> Connection -> IO ()
propagateReads hTunnel hOther = loop
  where
    loop :: IO ()
    loop = read hTunnel >>= maybe (return ()) (\payload -> write hOther payload >> loop)


-- | Copy every chunk read from @hOther@ to @hTunnel@.
--
-- Stops when a read yields an empty payload, or when it yields 'Nothing'. The
-- 'Nothing' case previously crashed via 'fromJust' instead of ending the
-- copy.
propagateWrites :: Connection -> Connection -> IO ()
propagateWrites hTunnel hOther = loop
  where
    loop :: IO ()
    loop = do
      mPayload <- read hOther
      case mPayload of
        Just payload | not (null payload) -> write hTunnel payload >> loop
        _                                 -> return ()