-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Naive
  ( PaddingType (..)
  , addPaddingConduit
  , parseRequestForPadding
  , prepareResponseForPadding
  , removePaddingConduit
  ) where

import Control.Monad             (replicateM, unless)
import Control.Monad.IO.Class    (liftIO)
import Data.Binary.Builder       qualified as BB
import Data.ByteString           qualified as BS
import Data.ByteString.Char8     qualified as BS8
import Data.ByteString.Lazy      qualified as LBS
import Data.Conduit.Binary       qualified as CB
import Data.Maybe                (mapMaybe)
import Network.HTTP.Types.Header qualified as HT
import System.Random             (uniformR)
import System.Random.Stateful
    (applyAtomicGen, globalStdGen, runStateGen, uniformRM)

import Data.Conduit
import Network.Wai

-- | Generate a random value for the legacy @Padding@ response header.
--
-- The result is between 32 and 63 characters long. The first 24 characters
-- are drawn uniformly from a fixed symbol set and the remainder is filled
-- with @'~'@. All random numbers are taken from 'globalStdGen' inside a
-- single 'applyAtomicGen' step.
randomPadding :: IO BS8.ByteString
randomPadding = applyAtomicGen (\g -> runStateGen g draw) globalStdGen
  where
    -- NOTE(review): the name of the original list ("nonHuffman") suggests
    -- these are characters HPACK's Huffman code handles poorly, so the
    -- header resists compression — confirm against naiveproxy.
    alphabet :: String
    alphabet = "!#$()+<>?@[]^`{}"

    -- Draw order matters for reproducibility: total length first, then the
    -- 24 symbol indices.
    draw gen = do
        total <- uniformRM (32, 63 :: Int) gen
        let pick = (alphabet !!) <$> uniformRM (0, length alphabet - 1) gen
        symbols <- replicateM 24 pick
        return $ BS8.pack (symbols ++ replicate (total - 24) '~')

-- | Draw an 'Int' uniformly from the inclusive range @[minv, maxv]@, using
-- the process-wide 'globalStdGen' updated atomically.
randInt :: Int -> Int -> IO Int
randInt minv maxv = flip applyAtomicGen globalStdGen $ uniformR (minv, maxv)

-- | Padding schemes that can be negotiated with a naiveproxy client.
--
-- See the corresponding definition in naiveproxy:
-- https://github.com/klzgrad/naiveproxy/blob/master/src/net/tools/naive/naive_protocol.h#L30C12-L30C23
data PaddingType
    = NoPadding  -- ^ Stream data is forwarded without framing.
    | Variant1   -- ^ The first frames carry length-prefixed random padding.
    deriving (Eq, Ord, Show)

-- | Decode a single padding-type token as it appears in a header value.
-- Only the exact tokens @"0"@ and @"1"@ are recognised; anything else
-- yields 'Nothing'.
parsePaddingType :: BS8.ByteString -> Maybe PaddingType
parsePaddingType token = lookup token knownTokens
  where
    knownTokens = [("0", NoPadding), ("1", Variant1)]

-- | Encode a padding type as the token sent in @Padding-Type-Reply@.
showPaddingType :: PaddingType -> BS8.ByteString
showPaddingType paddingType = case paddingType of
    NoPadding -> "0"
    Variant1  -> "1"

-- | Header names used to negotiate padding with naiveproxy clients.
legacyPaddingHeader, paddingTypeRequestHeader, paddingTypeReplyHeader
    :: HT.HeaderName

-- Carries random filler. A request with this header but without
-- Padding-Type-Request is treated as asking for Variant1.
legacyPaddingHeader = "Padding"

-- Client-side: comma-separated list of acceptable padding types.
paddingTypeRequestHeader = "Padding-Type-Request"

-- Server-side: the padding type that was chosen.
paddingTypeReplyHeader = "Padding-Type-Reply"

-- | A stream transformer over raw byte chunks, running in 'IO' with no
-- result value.
type PaddingConduit =
    ConduitT
        BS.ByteString -- incoming chunks
        BS.ByteString -- outgoing chunks
        IO
        ()

-- | Identity conduit: re-emits every incoming chunk unchanged until
-- upstream is exhausted.
noPaddingConduit :: PaddingConduit
noPaddingConduit = go
  where
    go = do
        next <- await
        case next of
            Nothing    -> return ()
            Just chunk -> yield chunk >> go

-- | Conduit that applies the given padding scheme to an outgoing stream.
-- 'NoPadding' forwards chunks unchanged; 'Variant1' emits the first
-- 'countPaddingsVariant1' frames with padding.
addPaddingConduit :: PaddingType -> PaddingConduit
addPaddingConduit paddingType = case paddingType of
    NoPadding -> noPaddingConduit
    Variant1  -> addPaddingVariant1 countPaddingsVariant1

-- | Conduit that strips the given padding scheme from an incoming stream.
-- 'NoPadding' forwards chunks unchanged; 'Variant1' decodes the first
-- 'countPaddingsVariant1' padded frames.
removePaddingConduit :: PaddingType -> PaddingConduit
removePaddingConduit paddingType = case paddingType of
    NoPadding -> noPaddingConduit
    Variant1  -> removePaddingVariant1 countPaddingsVariant1

-- | Decide which padding scheme, if any, a client requested.
--
-- * If the request carries @Padding-Type-Request@, its value is read as a
--   comma-separated list and the first entry naming a known scheme wins.
--   Optional whitespace (space or tab) around each entry is ignored, as
--   HTTP list syntax permits (e.g. @"1, 0"@). If no entry is recognised the
--   result is 'Nothing'.
-- * Otherwise, the presence of the legacy @Padding@ header selects
--   'Variant1'.
-- * Otherwise no padding is negotiated.
parseRequestForPadding :: Request -> Maybe PaddingType
parseRequestForPadding req
    | Just offered <- lookup paddingTypeRequestHeader headers =
        firstOf (mapMaybe (parsePaddingType . trimOWS) (BS8.split ',' offered))
    | Just _ <- lookup legacyPaddingHeader headers = Just Variant1
    | otherwise = Nothing
  where
    headers = requestHeaders req

    -- Total replacement for 'head': the first element, if any.
    firstOf (p : _) = Just p
    firstOf []      = Nothing

    -- Strip HTTP optional whitespace (SP / HTAB) from both ends of a list
    -- entry; without this, " 1" after a comma would never parse.
    trimOWS = BS8.dropWhile isOWS . fst . BS8.spanEnd isOWS
    isOWS c = c == ' ' || c == '\t'

-- | Build the response headers that confirm a padding negotiation.
--
-- 'Nothing' yields no headers. Otherwise the response gets a freshly
-- generated random @Padding@ value and a @Padding-Type-Reply@ header
-- naming the chosen scheme.
prepareResponseForPadding :: Maybe PaddingType -> IO [HT.Header]
prepareResponseForPadding = maybe (return []) withReply
  where
    withReply chosen = do
        filler <- randomPadding
        return
            [ (legacyPaddingHeader, filler)
            , (paddingTypeReplyHeader, showPaddingType chosen)
            ]

-- | Number of leading frames that carry padding under 'Variant1'; once
-- they have been sent or received, the stream is forwarded unframed.
--
-- See: https://github.com/klzgrad/naiveproxy/blob/master/src/net/tools/naive/naive_protocol.h#L34
countPaddingsVariant1 :: Int
countPaddingsVariant1 =
    8

-- | Wrap the next @n@ frames of the stream in 'Variant1' padding, then
-- forward the rest untouched.
--
-- Each emitted frame is a 3-byte header (payload length as two bytes, high
-- byte first, then the padding length) followed by the payload and that many
-- zero bytes. A chunk larger than one frame can hold is split and the
-- remainder pushed back with 'leftover'. Payloads of 401-1023 bytes are cut
-- down to a random 200-300 bytes. Payloads under 100 bytes get at least
-- @255 - length@ bytes of padding; others get 1-255.
--
-- NOTE(review): an empty chunk stops the conduit exactly like end of input,
-- so nothing after it is forwarded; confirm upstream never yields empty
-- chunks mid-stream.
addPaddingVariant1 :: Int -> PaddingConduit
addPaddingVariant1 0 = noPaddingConduit
addPaddingVariant1 n = await >>= maybe (return ()) onChunk
  where
    -- Largest payload a frame may carry: the 16-bit length field's maximum,
    -- minus the 3-byte header and the largest (255-byte) padding.
    maxPayload :: Int
    maxPayload = 65535 - 3 - 255

    onChunk :: BS.ByteString -> PaddingConduit
    onChunk chunk
        | BS.null chunk = return ()
        | otherwise     = do
            let available = min (BS.length chunk) maxPayload
            takeLen <- if available > 400 && available < 1024
                then liftIO (randInt 200 300)
                else return available
            let (payload, rest) = BS.splitAt takeLen chunk
            unless (BS.null rest) (leftover rest)
            let payloadLen = BS.length payload
                lowestPad
                    | payloadLen < 100 = 255 - payloadLen
                    | otherwise        = 1
            padLen <- liftIO (randInt lowestPad 255)
            yield (encodeFrame payload padLen)
            addPaddingVariant1 (n - 1)

    -- Serialise one frame: header, payload, then zero-byte padding.
    encodeFrame :: BS.ByteString -> Int -> BS.ByteString
    encodeFrame payload padLen =
        let payloadLen = BS.length payload
            headerBytes = [payloadLen `div` 256, payloadLen `mod` 256, padLen]
        in LBS.toStrict . BB.toLazyByteString $
                foldMap (BB.singleton . fromIntegral) headerBytes
             <> BB.fromByteString payload
             <> BB.fromByteString (BS.replicate padLen 0)

-- | Inverse of 'addPaddingVariant1': decode the next @n@ padded frames,
-- emitting only their payloads, then forward the rest of the stream as is.
--
-- Each frame is read as a 3-byte header (payload length as two bytes, high
-- byte first, then the padding length) followed by payload and padding. If
-- the stream ends before a full header or a full frame is available, the
-- conduit stops and the partial frame is discarded.
removePaddingVariant1 :: Int -> PaddingConduit
removePaddingVariant1 0 = noPaddingConduit
removePaddingVariant1 n = CB.take 3 >>= decodeHeader . LBS.unpack
  where
    decodeHeader [hi, lo, pad] = do
        let payloadLen = fromIntegral hi * 256 + fromIntegral lo
            frameLen   = payloadLen + fromIntegral pad
        frame <- CB.take (fromIntegral frameLen)
        if LBS.length frame == frameLen
            then do
                -- Drop the trailing padding, keep only the payload.
                yield (LBS.toStrict (LBS.take payloadLen frame))
                removePaddingVariant1 (n - 1)
            else return ()
    -- Fewer than 3 header bytes: upstream ended.
    decodeHeader _ = return ()