{-# LANGUAGE BangPatterns #-}

module Network.Wai.Handler.Warp.RequestHeader (
    parseHeaderLines,
) where

import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C8 (unpack)
import Data.ByteString.Internal (memchr)
import qualified Data.CaseInsensitive as CI
import Data.Word8
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, minusPtr, nullPtr, plusPtr)
import Foreign.Storable (peek)
import qualified Network.HTTP.Types as H
import UnliftIO (throwIO)

import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

-- $setup
-- >>> :set -XOverloadedStrings

----------------------------------------------------------------

-- | Parse the lines of a request head.
--
-- The first line is treated as the request line (see 'parseRequestLine');
-- every remaining line is parsed as a header field (see 'parseHeader').
--
-- The result holds the method, the raw path, that path passed through
-- 'H.extractPath', the query, the HTTP version and the header list.
-- An empty input throws 'NotEnoughLines' carrying an empty list.
parseHeaderLines
    :: [ByteString]
    -> IO
        ( H.Method
        , ByteString --  Path
        , ByteString --  Path, parsed
        , ByteString --  Query
        , H.HttpVersion
        , H.RequestHeaders
        )
parseHeaderLines headLines = case headLines of
    [] -> throwIO (NotEnoughLines [])
    requestLine : headerLines -> do
        (method, rawPath, query, version) <- parseRequestLine requestLine
        return
            ( method
            , rawPath
            , H.extractPath rawPath
            , query
            , version
            , map parseHeader headerLines
            )

----------------------------------------------------------------

-- | Parse an HTTP request line of the form
-- @METHOD SP request-target SP HTTP\/x.y@.
--
-- Returns the method, the path, the query and the HTTP version. The
-- method, path and query are zero-copy slices of the input. The query runs
-- from the first @?@ of the request target (inclusive) up to the second
-- space. It is empty when the target contains no @?@.
--
-- Errors, thrown in 'IO':
--
-- * 'BadFirstLine' is thrown when:
--
--     * the line is shorter than 14 bytes (the length of @GET \/ HTTP\/1.1@);
--     * the first space is missing or is followed by fewer than 10 bytes;
--     * the second space is missing or is followed by fewer than 8 bytes.
--
-- * 'NonHttp' is thrown when the token after the second space does not
--   have the shape @HTTP\/?.?@. The two version characters themselves are
--   not validated.
--
-- Any version other than HTTP\/1.1 and HTTP\/2.0 is reported as HTTP\/1.0.
-- Bytes after the 8-byte version token are not examined.
--
-- >>> parseRequestLine "GET / HTTP/1.1"
-- ("GET","/","",HTTP/1.1)
-- >>> parseRequestLine "POST /cgi/search.cgi?key=foo HTTP/1.0"
-- ("POST","/cgi/search.cgi","?key=foo",HTTP/1.0)
-- >>> parseRequestLine "GET "
-- *** Exception: Warp: Invalid first line of request: "GET "
-- >>> parseRequestLine "GET /NotHTTP UNKNOWN/1.1"
-- *** Exception: Warp: Request line specified a non-HTTP request
-- >>> parseRequestLine "PRI * HTTP/2.0"
-- ("PRI","*","",HTTP/2.0)
parseRequestLine
    :: ByteString
    -> IO
        ( H.Method
        , ByteString -- Path
        , ByteString -- Query
        , H.HttpVersion
        )
parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> do
    -- withForeignPtr keeps the underlying buffer alive while raw pointers
    -- into it are in use.
    when (len < 14) $ throwIO baderr
    let methodptr, limptr :: Ptr Word8
        methodptr = ptr `plusPtr` off
        -- One past the last byte of the request line.
        limptr = methodptr `plusPtr` len
        lim0 = fromIntegral len

    -- The first space terminates the method.
    pathptr0 <- memchr methodptr _space lim0
    when (pathptr0 == nullPtr || (limptr `minusPtr` pathptr0) < 11) $
        throwIO baderr
    let pathptr :: Ptr Word8
        pathptr = pathptr0 `plusPtr` 1
        -- Search only [pathptr, limptr). Measuring the length from pathptr0
        -- (as before) let memchr read one byte past the end of the line.
        lim1 = fromIntegral (limptr `minusPtr` pathptr)

    -- The second space terminates the request target.
    httpptr0 <- memchr pathptr _space lim1
    when (httpptr0 == nullPtr || (limptr `minusPtr` httpptr0) < 9) $
        throwIO baderr
    let httpptr :: Ptr Word8
        httpptr = httpptr0 `plusPtr` 1
        -- Length of the request target, i.e. [pathptr, httpptr0).
        lim2 = fromIntegral (httpptr0 `minusPtr` pathptr)

    checkHTTP httpptr
    !hv <- httpVersion httpptr
    -- The query starts at the first '?' inside the request target.
    queryptr <- memchr pathptr _question lim2

    let !method = bs ptr methodptr pathptr0
        !path
            | queryptr == nullPtr = bs ptr pathptr httpptr0
            | otherwise = bs ptr pathptr queryptr
        !query
            | queryptr == nullPtr = S.empty
            | otherwise = bs ptr queryptr httpptr0

    return (method, path, query, hv)
  where
    baderr :: InvalidRequest
    baderr = BadFirstLine $ C8.unpack requestLine

    -- Throw 'NonHttp' unless the byte at offset @n@ from @p@ equals @w@.
    check :: Ptr Word8 -> Int -> Word8 -> IO ()
    check p n w = do
        w0 <- peek (p `plusPtr` n)
        when (w0 /= w) $ throwIO NonHttp

    -- Verify the fixed characters of "HTTP/?.?". The digits at offsets 5
    -- and 7 are read, but not validated, by 'httpVersion'.
    checkHTTP :: Ptr Word8 -> IO ()
    checkHTTP p = do
        check p 0 _H
        check p 1 _T
        check p 2 _T
        check p 3 _P
        check p 4 _slash
        check p 6 _period

    -- Map the major/minor digits to a version. Only 1.1 and 2.0 are
    -- recognised; everything else falls back to HTTP/1.0.
    httpVersion :: Ptr Word8 -> IO H.HttpVersion
    httpVersion p = do
        major <- peek (p `plusPtr` 5) :: IO Word8
        minor <- peek (p `plusPtr` 7) :: IO Word8
        let version
                | major == _1 = if minor == _1 then H.http11 else H.http10
                | major == _2 && minor == _0 = H.http20
                | otherwise = H.http10
        return version

    -- Zero-copy slice [p0, p1) of the input, with offsets taken relative
    -- to the buffer base @base@.
    bs :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> ByteString
    bs base p0 p1 = PS fptr (p0 `minusPtr` base) (p1 `minusPtr` p0)

----------------------------------------------------------------

-- | Split one header line at its first colon into a case-insensitive field
-- name and a field value.
--
-- Optional whitespace (spaces and tabs) before and after the value is
-- removed, as RFC 9110 section 5.5 requires. The field name is kept
-- verbatim. A line without a colon yields the whole line as the name and
-- an empty value.
--
-- >>> parseHeader "Content-Length:47"
-- ("Content-Length","47")
-- >>> parseHeader "Accept-Ranges: bytes"
-- ("Accept-Ranges","bytes")
-- >>> parseHeader "Host:  example.com:8080"
-- ("Host","example.com:8080")
-- >>> parseHeader "Accept: text/html \t"
-- ("Accept","text/html")
-- >>> parseHeader "NoColon"
-- ("NoColon","")
parseHeader :: ByteString -> H.Header
parseHeader s = (CI.mk name, value)
  where
    (name, rest) = S.break (== _colon) s
    -- Drop the colon itself, then the leading and trailing OWS.
    value = fst . S.spanEnd isOWS . S.dropWhile isOWS $ S.drop 1 rest
    isOWS c = c == _space || c == _tab