-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Util
  ( parseHostPort
  , parseHostPortWithDefault
  , responseKnownLength
  ) where

import Data.ByteString       qualified as BS
import Data.ByteString.Char8 qualified as BS8
import Data.ByteString.Lazy  qualified as LBS
import Data.Maybe            (fromMaybe)

import Network.HTTP.Types (ResponseHeaders, Status)
import Network.Wai

-- | Split a @host:port@ string at its /last/ colon.
--
-- On success returns the text before the last @:@ (possibly empty) as the
-- host, and the parsed port.  The text after the colon must be a non-empty
-- run of ASCII decimal digits whose value lies in @1..65535@; anything else
-- (no colon, empty port, sign characters, trailing garbage, out-of-range
-- value) yields 'Nothing'.
--
-- Splitting at the last colon means a bracketed IPv6 literal such as
-- @[::1]:443@ yields @(\"[::1]\", 443)@.
parseHostPort :: BS.ByteString -> Maybe (BS.ByteString, Int)
parseHostPort hostPort = do
    lastColon <- BS8.elemIndexEnd ':' hostPort
    let (host, rest) = BS.splitAt lastColon hostPort
        portStr      = BS.drop 1 rest   -- skip the ':' itself
    port <- parsePort portStr
    return (host, port)
  where
    -- Parse a decimal port number, rejecting anything that is not a plain
    -- digit string in the valid TCP/UDP port range.
    parsePort :: BS.ByteString -> Maybe Int
    parsePort str
        | BS.null str                    = Nothing
        -- 'BS8.readInt' tolerates a leading sign; RFC 3986 only allows digits.
        | not (BS8.all isAsciiDigit str) = Nothing
        -- A valid port has at most 5 significant digits.  Rejecting longer
        -- strings up front avoids depending on 'BS8.readInt' overflow
        -- behaviour, which may wrap silently in some older bytestring releases.
        | BS.length (BS8.dropWhile (== '0') str) > 5 = Nothing
        | otherwise = case BS8.readInt str of
            Just (p, remaining)
                | BS.null remaining && 1 <= p && p <= 65535 -> Just p
            _ -> Nothing

    isAsciiDigit :: Char -> Bool
    isAsciiDigit c = '0' <= c && c <= '9'

-- | Like 'parseHostPort', but never fails.
--
-- When the input cannot be split into a host and a valid port, the whole
-- input is returned as the host, paired with @defaultPort@.
parseHostPortWithDefault :: Int -> BS.ByteString -> (BS.ByteString, Int)
parseHostPortWithDefault defaultPort hostPort =
    fromMaybe (hostPort, defaultPort) (parseHostPort hostPort)

-- | Build a 'Response' from a lazy body and attach a @Content-Length@
-- header computed from that body.
--
-- Any @Content-Length@ entry already present in @headers@ is dropped.
-- Header names compare case-insensitively, so the response never carries
-- two, possibly conflicting, length headers.
--
-- Note: 'LBS.length' walks every chunk of the lazy body, so the whole
-- body is forced before the response is built.
responseKnownLength :: Status -> ResponseHeaders -> LBS.ByteString -> Response
responseKnownLength status headers bs =
    responseLBS status (otherHeaders ++ [contentLength]) bs
  where
    otherHeaders  = filter ((/= "Content-Length") . fst) headers
    contentLength = ("Content-Length", BS8.pack (show (LBS.length bs)))