{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

module Http.Header
  ( Header (..)
  , decodeMany
  , parser
  , parserSmallArray
  , builder
  , builderSmallArray
  ) where

import Data.Bytes (Bytes)
import Data.Bytes.Builder (Builder)
import Data.Bytes.Parser (Parser)
import Data.Bytes.Types (Bytes (Bytes))
import Data.Primitive (ByteArray (ByteArray), SmallArray, SmallMutableArray)
import Data.Text (Text)
import Data.Word (Word8)

import Data.Bytes qualified as Bytes
import Data.Bytes.Builder qualified as Builder
import Data.Bytes.Parser qualified as Parser
import Data.Bytes.Parser.Latin qualified as Latin
import Data.Bytes.Text.Utf8 qualified as Utf8
import Data.Primitive qualified as PM
import Data.Text.Array qualified
import Data.Text.Internal qualified as Text

{- | A single HTTP header field: a name paired with its value.

Constructing a 'Header' directly performs no character-set checks. A name
that contains a colon, or a value that contains CR or LF, is serialized
verbatim by 'builder' and therefore yields a malformed message.
-}
data Header = Header
  { name :: {-# UNPACK #-} !Text
  -- ^ The field name, for example @Content-Type@.
  , value :: {-# UNPACK #-} !Text
  -- ^ The field value. When produced by 'parser', leading and trailing
  -- spaces and tabs have already been removed.
  }
  deriving (Eq, Show)

-- | Filler element for the mutable array that 'parserSmallArray' allocates
-- before any header has been parsed. Every slot that survives into the
-- returned array is written first, and the array is shrunk to the number
-- of written slots, so this value should never be forced. If it ever is,
-- the message names the function that allocated the array.
uninitializedHeader :: Header
{-# NOINLINE uninitializedHeader #-}
uninitializedHeader =
  errorWithoutStackTrace "Http.Header.parserSmallArray: uninitialized header"

{- | Decode a complete header block.

The input must be zero or more header fields, each terminated by CRLF,
followed by one more CRLF that ends the block. A non-empty block therefore
ends with two CRLF sequences in a row; an empty block is a single CRLF.

The first argument is the maximum number of headers accepted. The result
is 'Nothing' in three cases:

* a header is malformed;
* there are more headers than the maximum;
* any bytes remain after the terminating CRLF.
-}
decodeMany :: Int -> Bytes -> Maybe (SmallArray Header)
decodeMany !n !b = Parser.parseBytesMaybe headersThenEnd b
  where
    headersThenEnd :: Parser () s (SmallArray Header)
    headersThenEnd = do
      headers <- parserSmallArray n
      -- Reject leftovers after the blank line.
      Parser.endOfInput ()
      pure headers

-- | Parse a header block: zero or more header fields, each terminated by
-- CRLF, followed by an empty line (a bare CRLF). Parsing stops immediately
-- after that final CRLF; any input after it is left for the caller.
--
-- Fails if the block contains more headers than the given maximum. A
-- negative maximum is treated as zero, so in that case only an empty block
-- (a lone CRLF) is accepted.
parserSmallArray ::
  Int -> -- maximum number of headers allowed, recommended 128
  Parser () s (SmallArray Header)
parserSmallArray !n = do
  -- Clamp before allocating. A negative size must never reach
  -- newSmallArray, because the underlying primop does not reject it.
  let !capacity = max 0 n
  dst <- Parser.effect (PM.newSmallArray capacity uninitializedHeader)
  parserHeaderStep 0 capacity dst

-- | Worker loop for 'parserSmallArray'.
--
-- @ix@ is the next unwritten slot of @dst@, and @remaining@ is how many
-- more headers may still be stored.
--
-- * A CR at the current position ends the block. It must be followed by
--   LF; the array is then shrunk to the @ix@ slots written so far and
--   frozen.
-- * Any other byte starts another header. That is a parse failure once
--   @remaining@ is zero or less.
parserHeaderStep ::
  Int -> -- index
  Int -> -- remaining length
  SmallMutableArray s Header ->
  Parser () s (SmallArray Header)
parserHeaderStep !ix !remaining !dst = do
  sawCarriageReturn <- Latin.trySatisfy (== '\r')
  if sawCarriageReturn
    then do
      Latin.char () '\n'
      Parser.effect $ do
        PM.shrinkSmallMutableArray dst ix
        PM.unsafeFreezeSmallArray dst
    else
      if remaining <= 0
        then Parser.fail ()
        else do
          header <- parser
          Parser.effect (PM.writeSmallArray dst ix header)
          parserHeaderStep (ix + 1) (remaining - 1) dst

-- Byte values of the non-alphanumeric characters that 'parser' checks for.
-- The RFC 7230 @tchar@ symbols come first, followed by horizontal tab.
pattern Bang, Pound, Dollar, Percent, Ampersand, SingleQuote, Asterisk,
  Plus, Hyphen, Period, Caret, Underscore, Backtick, Pipe, Twiddle,
  HorizontalTab :: Word8
pattern Bang = 0x21 -- '!'
pattern Pound = 0x23 -- '#'
pattern Dollar = 0x24 -- '$'
pattern Percent = 0x25 -- '%'
pattern Ampersand = 0x26 -- '&'
pattern SingleQuote = 0x27 -- '\''
pattern Asterisk = 0x2A -- '*'
pattern Plus = 0x2B -- '+'
pattern Hyphen = 0x2D -- '-'
pattern Period = 0x2E -- '.'
pattern Caret = 0x5E -- '^'
pattern Underscore = 0x5F -- '_'
pattern Backtick = 0x60 -- '`'
pattern Pipe = 0x7C -- '|'
pattern Twiddle = 0x7E -- '~'
pattern HorizontalTab = 0x09 -- '\t'

{- | Parse a single HTTP header field, including its trailing CRLF.
From RFC 7230:

> token          = 1*tchar
> tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
>                / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
>                / DIGIT / ALPHA
>
> header-field   = field-name ":" OWS field-value OWS
> field-name     = token
> field-value    = *( field-content / obs-fold )
> field-content  = field-vchar [ 1*( SP / HTAB ) field-vchar ]
> field-vchar    = VCHAR / obs-text

An empty field name is rejected, as @1*tchar@ requires. Leading and
trailing spaces and tabs around the value are removed.

Deliberate deviations from the grammar:

* obs-text (bytes 0x80 to 0xFF) is rejected. The name and value are
  converted to 'Text' without UTF-8 validation, so only ASCII bytes are
  accepted.
* obs-fold (line folding) is not supported. A value ends at the first CR.
-}
parser :: Parser () s Header
parser = do
  -- Header name may contain: a-z, A-Z, 0-9, several different symbols
  !name <- Parser.takeWhile isTokenByte
  -- takeWhile also succeeds on zero bytes, but RFC 7230 requires 1*tchar.
  if Bytes.null name then Parser.fail () else pure ()
  Latin.char () ':'
  Latin.skipWhile (\c -> c == ' ' || c == '\t')
  -- Header value allows vchar, space, and tab.
  value0 <- Parser.takeWhile isValueByte
  Latin.char2 () '\r' '\n'
  -- We only need to trim the end because the leading spaces and tab
  -- were already skipped.
  let !value = Bytes.dropWhileEnd isOwsByte value0
  pure Header {name = unsafeBytesToText name, value = unsafeBytesToText value}

-- | True for bytes allowed in an RFC 7230 token (@tchar@).
isTokenByte :: Word8 -> Bool
{-# INLINE isTokenByte #-}
isTokenByte c =
  (c >= 0x41 && c <= 0x5A) -- A-Z
    || (c >= 0x61 && c <= 0x7A) -- a-z
    || (c >= 0x30 && c <= 0x39) -- 0-9
    || c == Bang
    || c == Pound
    || c == Dollar
    || c == Percent
    || c == Ampersand
    || c == SingleQuote
    || c == Asterisk
    || c == Plus
    || c == Hyphen
    || c == Period
    || c == Caret
    || c == Underscore
    || c == Backtick
    || c == Pipe
    || c == Twiddle

-- | True for bytes accepted in a header value: printable ASCII
-- (0x20 to 0x7E) and horizontal tab.
isValueByte :: Word8 -> Bool
{-# INLINE isValueByte #-}
isValueByte c = (c >= 0x20 && c <= 0x7E) || c == HorizontalTab

-- | True for optional-whitespace bytes (space or horizontal tab).
isOwsByte :: Word8 -> Bool
{-# INLINE isOwsByte #-}
isOwsByte c = c == 0x20 || c == HorizontalTab

-- | Reinterpret a byte slice as 'Text' without copying and without UTF-8
-- validation. The resulting 'Text' shares the slice's backing array,
-- offset and length.
--
-- The caller must guarantee that the bytes are valid UTF-8. In this module
-- it is called from 'parser', which admits only ASCII bytes.
--
-- This relies on the UTF-8 internal representation of @text@ (>= 2.0),
-- where "Data.Text.Array" wraps a bare byte array.
unsafeBytesToText :: Bytes -> Text
{-# INLINE unsafeBytesToText #-}
unsafeBytesToText bytes = case bytes of
  Bytes (ByteArray payload) start count ->
    Text.Text (Data.Text.Array.ByteArray payload) start count

-- | Serialize a header as @name: value@ followed by CRLF.
--
-- The name and value are written as their UTF-8 bytes exactly as stored;
-- nothing is escaped or validated.
builder :: Header -> Builder
builder (Header {name = fieldName, value = fieldValue}) =
  mconcat
    [ Builder.copy (Utf8.fromText fieldName)
    , Builder.ascii2 ':' ' '
    , Builder.copy (Utf8.fromText fieldValue)
    , Builder.ascii2 '\r' '\n'
    ]

-- | Serialize every header in array order, each followed by its own CRLF.
-- The blank line that terminates a header block is not emitted.
builderSmallArray :: SmallArray Header -> Builder
builderSmallArray headers = foldr (\h rest -> builder h <> rest) mempty headers