-- This file is part of htalkat
-- Copyright (C) 2021 Martin Bays
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of version 3 of the GNU General Public License as
-- published by the Free Software Foundation, or any later version.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see http://www.gnu.org/licenses/.

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Relay a locally accepted stream connection over a TLS connection as
-- timed text: characters interleaved with the pauses between them, so the
-- receiving end can reproduce the original timing.
module RelayStream where

import Control.Concurrent
import Control.Exception (SomeException, finally, handle, mask_)
import Control.Monad (foldM_, forever, unless, void, when)
import System.Timeout (timeout)

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.Text.Encoding.Error as T
import qualified Data.Text.Lazy as T
import qualified Data.Text.Lazy.Encoding as T
import qualified Network.Socket as S
import qualified Network.Socket.ByteString.Lazy as SL
import qualified Network.TLS as TLS
import qualified Time.System as TM
import qualified Time.Types as TM

import Mundanities
import TimedText

-- | Whether this end sends its handshake byte as soon as it can
-- ('WriteFirst'), or only after the peer's first chunk has been received and
-- checked ('WriteSecond').
data WriteOrder = WriteFirst | WriteSecond
  deriving (Eq,Ord,Show)

-- | @relayStream ctxt ord dSock@ relays between the TLS connection @ctxt@ and
-- the first connection accepted on the listening socket @dSock@.
--
-- * Handshake: each side sends the single byte @\'T\'@ over TLS, and the first
--   chunk received must start with it, otherwise the relay aborts. Our byte is
--   sent once the local connection has been accepted (and, with
--   'WriteSecond', once the peer's first chunk has been checked).
--
-- * Outgoing: bytes read from the local connection are decoded as UTF-8
--   (leniently, replacing invalid sequences), the millisecond pauses between
--   characters are recorded, and the result is gathered in windows of about
--   300ms, encoded with 'encodeTimedText' and sent over TLS.
--
-- * Incoming: data received over TLS is decoded with 'decodeTimedText' and
--   written to the local connection as UTF-8, sleeping for the recorded
--   pauses. Output starts 300ms after the first character arrives.
--
-- Returns once any part of the relay ends or fails: EOF on either side, a bad
-- handshake, or an exception in one of the guarded worker threads. It then
-- closes the TLS session with 'TLS.bye' and, whether or not that succeeds,
-- stops waiting for a local connection and gracefully closes it if one was
-- accepted. The other worker threads are not explicitly killed.
relayStream :: TLS.Context -> WriteOrder -> S.Socket -> IO ()
relayStream ctxt ord dSock = do
  receivedHandshake <- newEmptyMVar
  finished <- newEmptyMVar
  rawInChan <- newChan
  let -- Ask the main thread to shut down. Several threads may call this, some
      -- more than once; 'tryPutMVar' makes repeat calls no-ops rather than
      -- leaving the caller blocked forever on an already-full MVar.
      abort = void $ tryPutMVar finished ()

      -- Run an action, turning any exception it throws into an 'abort'.
      abortOnErr = handle abortHandler
        where
          abortHandler :: Monoid a => SomeException -> IO a
          abortHandler _ = abort >> pure mempty

      -- Forward every chunk received over TLS to rawInChan, checking that the
      -- very first chunk starts with the intro byte. An empty chunk means the
      -- TLS stream has ended.
      recvAll = do
        b <- TLS.recvData ctxt
        case BS.uncons b of
          Nothing -> abort
          Just (h,_) -> do
            ok <- tryReadMVar receivedHandshake >>= \case
              Just known -> pure known
              Nothing -> do
                -- First chunk: record whether it carried the intro byte.
                -- NOTE(review): the chunk is forwarded with the intro byte
                -- still at its head; presumably 'decodeTimedText' expects or
                -- tolerates this - confirm.
                let isHandshakeByte = h == introByte
                putMVar receivedHandshake isHandshakeByte
                if isHandshakeByte then pure True else abort >> pure False
            if ok
              then writeChan rawInChan b >> recvAll
              else writeChan rawInChan BS.empty

      -- Send our intro byte; with 'WriteSecond', first wait until the peer's
      -- first chunk has been received and checked.
      sendHandshake = do
        when (ord == WriteSecond) . void $ readMVar receivedHandshake
        TLS.sendData ctxt $ BL.singleton introByte

  -- Accept the local connection in the background, so that waiting for it
  -- can be abandoned if the relay finishes first.
  sockMV <- newEmptyMVar
  sockThread <- forkIO $ putMVar sockMV . fst =<< S.accept dSock

  -- Outgoing: local connection -> TLS.
  _ <- forkIO $ do
    sock <- readMVar sockMV
    abortOnErr sendHandshake
    tsOutChan <- newChan
    rawOutChan <- newChan
    _ <- forkIO . abortOnErr $ do
      writeList2Chan rawOutChan . T.unpack . T.decodeUtf8With T.lenientDecode
        =<< SL.getContents sock
      -- The local connection reached EOF.
      abort
    pausesThread <- forkIO $ insertPauses rawOutChan tsOutChan
    abortOnErr $ sendAll tsOutChan
    killThread pausesThread

  -- Incoming: TLS -> local connection.
  _ <- forkIO $ do
    tsInChan <- newChan
    decodeTTThread <- forkIO $
      writeList2Chan tsInChan . decodeTimedText . BL.fromChunks
        =<< getChanContents rawInChan
    _ <- forkIO . abortOnErr $ relayTimed tsInChan =<< readMVar sockMV
    abortOnErr recvAll
    killThread decodeTTThread

  -- Tear down once any component has called 'abort'. The local side is
  -- cleaned up even if 'TLS.bye' throws (e.g. because the connection is
  -- already broken); such an exception still propagates to the caller.
  let closeLocal = do
        ignoreIOErr $ killThread sockThread
        tryTakeMVar sockMV >>= \case
          Nothing -> pure ()
          Just sock -> S.gracefulClose sock 1000
  takeMVar finished
  TLS.bye ctxt `finally` closeLocal
  where
    -- The single handshake byte each side sends first: ASCII 'T'.
    introByte = fromIntegral $ fromEnum 'T'

    -- Copy characters from rawChan to ttChan, preceding each one with
    -- @Left ms@ (only when ms > 0): the milliseconds since the previous
    -- character was read or, for the first one, since this loop started.
    insertPauses rawChan ttChan = TM.timeCurrentP >>= insertPauses'
      where
        insertPauses' e = do
          c <- readChan rawChan
          e' <- TM.timeCurrentP
          let ms = elapsedPToMS $ e' - e
          when (ms > 0) . writeChan ttChan . Left $ fromIntegral ms
          writeChan ttChan $ Right c
          insertPauses' e'

    -- Repeatedly gather timed text from ttChan for one window (sendTimeout),
    -- then send whatever was gathered over TLS in one go.
    sendAll ttChan = forever $ do
      readBufMV <- newMVar []
      -- 'timeout' interrupts the gathering loop with an asynchronous
      -- exception, so the gathered items are kept in an MVar. Each
      -- read-and-store step is masked, so the exception can only land while
      -- 'readChan' is blocked (nothing consumed yet) or between steps. It can
      -- never land after an item has left the channel but before it has been
      -- stored, which would silently drop that item.
      _ <- timeout sendTimeout . forever . mask_ $ do
        x <- readChan ttChan
        modifyMVar_ readBufMV $ pure . (x :)
      rtt <- readMVar readBufMV
      unless (null rtt) . TLS.sendData ctxt . rechunk . encodeTimedText $
        reverse rtt
      where
        rechunk =
          -- TLS.sendData sends one packet per chunk, while encodeTimedText
          -- returns a chunk per char, so it's important to rechunk.
          BL.fromStrict . BL.toStrict
        -- Length of one gathering window, in microseconds (300ms).
        sendTimeout = 1000 * 300

    -- Replay timed text from chan to the local socket, reproducing the
    -- recorded pauses. Timing follows a schedule (the Just value): each pause
    -- advances the schedule by its length, and we sleep only for whatever
    -- part of it has not already elapsed.
    relayTimed chan sock = foldM_ sendTimed' Nothing =<< getChanContents chan
      where
        sendTimed' :: Maybe TM.ElapsedP -> Either Int Char -> IO (Maybe TM.ElapsedP)
        -- No schedule yet: wait bufferTime, then start the schedule at the
        -- current time.
        sendTimed' Nothing (Right c) = do
          threadDelay bufferTime
          e <- TM.timeCurrentP
          sendTimed' (Just e) (Right c)
        -- Pauses arriving while there is no schedule are dropped.
        sendTimed' Nothing _ = pure Nothing
        sendTimed' (Just e) (Right c) = do
          SL.sendAll sock . T.encodeUtf8 $ T.singleton c
          pure $ Just e
        sendTimed' (Just e) (Left n) = do
          delayed <- elapsedPToMS . flip (-) e <$> TM.timeCurrentP
          when (n > delayed) . threadDelay . (1000 *) $ n - delayed
          -- A pause of exactly pauseMax that had already been overrun resets
          -- the schedule, so the next character gets a fresh bufferTime.
          pure $ if n == pauseMax && n < delayed
            then Nothing
            else Just $ e + msToElapsedP n
        -- Delay before the first character of a schedule, in microseconds
        -- (300ms).
        bufferTime = 1000 * 300

-- | Convert a (possibly negative) number of milliseconds to an 'TM.ElapsedP'.
-- 'divMod' floors, so the nanosecond part is always in [0, 1e9).
msToElapsedP :: Int -> TM.ElapsedP
msToElapsedP ms | (s,ms') <- fromIntegral ms `divMod` 1000 =
  TM.ElapsedP (TM.Elapsed (TM.Seconds s)) (TM.NanoSeconds $ 1000000 * ms')

-- | Convert an 'TM.ElapsedP' to milliseconds, rounding the nanosecond part
-- down to whole milliseconds.
elapsedPToMS :: TM.ElapsedP -> Int
elapsedPToMS (TM.ElapsedP (TM.Elapsed (TM.Seconds s)) (TM.NanoSeconds ns)) =
  fromIntegral $ s*1000 + ns `div` 1000000