{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE OverloadedStrings   #-}

module Protocols where

import           ClassyPrelude
import           Control.Concurrent        (forkFinally, threadDelay)
import qualified Data.HashMap.Strict       as H
import           System.IO                 hiding (hSetBuffering, hGetBuffering)

import qualified Data.ByteString.Char8     as BC

import qualified Data.Streaming.Network    as N

import           Network.Socket            (HostName, PortNumber)
import qualified Network.Socket            as N hiding (recv, recvFrom, send,
                                                 sendTo)
import qualified Network.Socket.ByteString as N

import           Data.Binary               (Binary, decode, decodeOrFail,
                                            encode)

import           Logger
import qualified Socks5
import           Types


-- | Run @app@ repeatedly against the process' stdin/stdout.
--
-- For the lifetime of the server, stdin uses block buffering (buffer size
-- 512) and stdout is unbuffered. Because 'forever' can only exit through an
-- exception, the previous buffering modes are restored, and the close is
-- logged, from 'finally'. Without it, that code would be unreachable and
-- the terminal would be left in the modified buffering modes.
runSTDIOServer :: (StdioAppData -> IO ()) -> IO ()
runSTDIOServer app = do
  oldStdinBuffering  <- hGetBuffering stdin
  oldStdoutBuffering <- hGetBuffering stdout

  hSetBuffering stdin (BlockBuffering (Just 512))
  hSetBuffering stdout NoBuffering

  let restore = do
        hSetBuffering stdin oldStdinBuffering
        hSetBuffering stdout oldStdoutBuffering
        info "CLOSE stdio server"

  void (forever $ app StdioAppData) `finally` restore

-- | Listen for TCP connections on @endPoint@ and handle each one with @app@.
--
-- The listening socket uses 'defaultRecvBufferSize' as its read buffer size.
-- 'N.runTCPServer' has type @IO a@ and never returns normally. The close
-- message is therefore logged from 'finally'; placed after the call, it
-- would be unreachable.
runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPServer endPoint@(host, port) app = do
  info $ "WAIT for tcp connection on " <> toStr endPoint
  let srvSet = N.setReadBufferSize defaultRecvBufferSize
             $ N.serverSettingsTCP (fromIntegral port) (fromString host)
  void (N.runTCPServer srvSet app)
    `finally` info ("CLOSE tcp server on " <> toStr endPoint)

-- | Open a TCP connection to @endPoint@ and run @app@ on it.
--
-- The connection uses 'defaultRecvBufferSize' as its read buffer size. A
-- log line is written before connecting and another once @app@ returns.
runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPClient endPoint@(host, port) app = do
  info ("CONNECTING to " <> toStr endPoint)
  void (N.runTCPClient clientSet app)
  info ("CLOSE connection to " <> toStr endPoint)
  where
    clientSet = N.setReadBufferSize defaultRecvBufferSize
                  (N.clientSettingsTCP (fromIntegral port) (BC.pack host))


-- | Create a UDP socket towards @endPoint@ and run @app@ with it.
--
-- 'appRead' returns the payload of the next received datagram (at most 4096
-- bytes), and 'appWrite' sends a whole payload to the resolved address.
-- 'bracket' closes the socket once @app@ finishes, even if it throws.
runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPClient endPoint@(host, port) app = do
  info $ "SENDING datagrammes to " <> toStr endPoint
  bracket acquire (N.close . fst) $ \(sock, addrInfo) -> do
    let dest = N.addrAddress addrInfo
    sem <- newEmptyMVar
    app UdpAppData { appAddr  = dest
                   , appSem   = sem
                   , appRead  = fst <$> N.recvFrom sock 4096
                   , appWrite = \payload -> N.sendAllTo sock payload dest
                   }
  info $ "CLOSE udp connection to " <> toStr endPoint
  where
    acquire = N.getSocketUDP host (fromIntegral port)


-- | Serve UDP datagrams received on @endPoint@.
--
-- Datagrams are demultiplexed by source address:
--
-- * The first datagram from an unknown address registers a new client and
--   forks a thread that runs @app@ for at most @cnxTimeout@ microseconds.
-- * Later datagrams from that address are handed to the client thread
--   through its 'appSem' 'MVar'.
--
-- A client is unregistered when its thread ends, whatever the reason.
runUDPServer :: (HostName, PortNumber) -> Int -> (UdpAppData -> IO ()) -> IO ()
runUDPServer endPoint@(host, port) cnxTimeout app = do
  info $ "WAIT for datagrames on " <> toStr endPoint
  clientsCtx <- newIORef mempty
  -- 'run' never returns normally, so the close is logged from 'finally'.
  bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (run clientsCtx)
    `finally` info ("CLOSE udp server on " <> toStr endPoint)

  where
    -- | Register a client for @addr@. Its mailbox is seeded with the first
    -- datagram, so the client's first 'appRead' returns it.
    addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString -> IO UdpAppData
    addNewClient clientsCtx socket addr payload = do
      sem <- newMVar payload
      let appData = UdpAppData { appAddr  = addr
                               , appSem   = sem
                               , appRead  = takeMVar sem
                               , appWrite = \payload' -> void $ N.sendAllTo socket payload' addr
                               }
      void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
      return appData

    -- | Drop a client from the registry once its thread has ended.
    removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO ()
    removeClient clientsCtx clientCtx = do
      void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ()))
      debug "TIMEOUT connection"

    -- | Hand a datagram to an already-registered client.
    --
    -- If we are unlucky, the client's thread died before we pushed data onto
    -- an already-full MVar, which would leave us waiting forever for it to
    -- empty. So catch the exception and drop the message: UDP is not a
    -- reliable protocol, so transmission failures must be handled by the
    -- application layer.
    pushDataToClient :: UdpAppData -> ByteString -> IO ()
    pushDataToClient clientCtx payload = putMVar (appSem clientCtx) payload
      `catch` (\(_ :: SomeException) -> debug $ "DROP udp packet, client thread dead")

    -- | Run the event loop in its own thread, then park the calling thread.
    --
    -- The loop runs in another thread so that the Haskell runtime does not
    -- send BlockedIndefinitelyOnMVar to the calling thread. For the same
    -- reason, no MVar is used to wait for the loop to end.
    --
    -- The loop is forked exactly once. Only the sleep is repeated, because
    -- 'threadDelay' of 'maxBound' microseconds is finite (about 36 minutes
    -- where Int is 32 bits). Waking up must not start a second loop on the
    -- same socket.
    run :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
    run clientsCtx socket = do
      _ <- forkFinally (runEventLoop clientsCtx socket) (\_ -> debug "UdpServer died")
      forever $ threadDelay (maxBound :: Int)

    -- | Receive datagrams (up to 4096 bytes each) and dispatch them by
    -- source address, forking a new client thread for unknown addresses.
    runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
    runEventLoop clientsCtx socket = forever $ do
      (payload, addr) <- N.recvFrom socket 4096
      knownClient <- H.lookup addr <$> readIORef clientsCtx

      case knownClient of
        Just clientCtx -> pushDataToClient clientCtx payload
        Nothing        -> do
          clientCtx <- addNewClient clientsCtx socket addr payload
          void $ forkFinally (void . timeout cnxTimeout $ app clientCtx)
                             (\_ -> removeClient clientsCtx clientCtx)


-- | Run a SOCKS5 proxy on the address described by @socksSettings@.
--
-- For every client connection:
--
-- 1. Answer the authentication request with "no authentication".
-- 2. Read the forward request and reply SUCCEEDED.
-- 3. Hand the connection to @inner@, with @cfg@'s destination replaced by
--    the requested host and port.
--
-- Handshake messages come from the network and are decoded with
-- 'decodeOrFail'. A connection sending undecodable bytes is logged and
-- dropped, rather than failing later with a lazy 'decode' error.
runSocks5Server :: Socks5.ServerSettings -> TunnelSettings -> (TunnelSettings -> N.AppData -> IO ()) -> IO ()
runSocks5Server socksSettings@Socks5.ServerSettings{..} cfg inner = do
  info $ "Starting socks5 proxy " <> show socksSettings

  -- 'N.runTCPServer' never returns normally, so the close is logged from 'finally'.
  N.runTCPServer (N.serverSettingsTCP (fromIntegral listenOn) (fromString bindOn)) handleClient
    `finally` info ("Closing socks5 proxy " <> show socksSettings)

  where
    -- | Decode one SOCKS5 message from a single read. On malformed input,
    -- return the decoder's error message instead of throwing.
    readMsg :: Binary a => N.AppData -> IO (Either String a)
    readMsg cnx = do
      bytes <- N.appRead cnx
      return $ case decodeOrFail (fromStrict bytes) of
        Left (_, _, err)  -> Left err
        Right (_, _, msg) -> Right msg

    -- | Authentication phase: accept the client with "no auth".
    handleClient :: N.AppData -> IO ()
    handleClient cnx = do
      authResult <- readMsg cnx :: IO (Either String Socks5.RequestAuth)
      case authResult of
        Left err -> debug $ "Socks5 invalid authentication request: " <> err
        Right authRequest -> do
          debug $ "Socks5 authentification request " <> show authRequest
          let responseAuth = encode $ Socks5.ResponseAuth (fromIntegral Socks5.socksVersion) Socks5.NoAuth
          N.appWrite cnx (toStrict responseAuth)
          handleRequest cnx

    -- | Request phase: read the destination, then update the tunnel config
    -- dynamically and run the tunnel.
    handleRequest :: N.AppData -> IO ()
    handleRequest cnx = do
      requestResult <- readMsg cnx :: IO (Either String Socks5.Request)
      case requestResult of
        Left err -> debug $ "Socks5 invalid forward request: " <> err
        Right request -> do
          debug $ "Socks5 forward request " <> show request
          let responseRequest = encode $ Socks5.Response (fromIntegral Socks5.socksVersion) Socks5.SUCCEEDED (Socks5.addr request) (Socks5.port request)
              cfg' = cfg { destHost = Socks5.addr request, destPort = Socks5.port request }
          N.appWrite cnx (toStrict responseRequest)
          inner cfg' cnx