{-# LANGUAGE BangPatterns #-}

module Network.Wai.Handler.Warp.RequestHeader (
      parseHeaderLines
    ) where

import UnliftIO (throwIO)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C8 (unpack)
import Data.ByteString.Internal (memchr)
import qualified Data.CaseInsensitive as CI
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, plusPtr, minusPtr, nullPtr)
import Foreign.Storable (peek)
import qualified Network.HTTP.Types as H

import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

-- $setup
-- >>> :set -XOverloadedStrings

----------------------------------------------------------------

-- | Split the raw lines of a request head into the request line and its
-- header fields, and parse both.
--
-- The first element is parsed with 'parseRequestLine'; every following
-- element is parsed with 'parseHeader'.  The result tuple holds the method,
-- the raw request path, that path passed through 'H.extractPath', the query
-- string, the HTTP version and the header list.
--
-- Throws @NotEnoughLines []@ when the input list is empty; errors from
-- 'parseRequestLine' propagate unchanged.
parseHeaderLines :: [ByteString]
                 -> IO (H.Method
                       ,ByteString  --  Path
                       ,ByteString  --  Path, parsed
                       ,ByteString  --  Query
                       ,H.HttpVersion
                       ,H.RequestHeaders
                       )
parseHeaderLines rawLines = case rawLines of
    [] -> throwIO (NotEnoughLines [])
    requestLine : fieldLines -> do
        (verb, rawPath, queryString, version) <- parseRequestLine requestLine
        let fields = parseHeader <$> fieldLines
        pure (verb, rawPath, H.extractPath rawPath, queryString, version, fields)

----------------------------------------------------------------

-- | Parse an HTTP request line into its method, path, query string and
-- HTTP version.
--
-- The line is scanned in place with 'memchr'.  The returned 'ByteString's
-- are slices that share @requestLine@'s buffer rather than copies.  The
-- query string, when present, keeps its leading @?@.
--
-- Throws 'BadFirstLine' when the line is shorter than @"GET / HTTP/1.1"@ or
-- lacks the two separating spaces.  Throws 'NonHttp' when the protocol token
-- does not have the shape @HTTP/?.?@.  Only versions 1.1 and 2.0 are
-- recognised explicitly; every other version is reported as HTTP/1.0.
--
-- >>> parseRequestLine "GET / HTTP/1.1"
-- ("GET","/","",HTTP/1.1)
-- >>> parseRequestLine "POST /cgi/search.cgi?key=foo HTTP/1.0"
-- ("POST","/cgi/search.cgi","?key=foo",HTTP/1.0)
-- >>> parseRequestLine "GET "
-- *** Exception: Warp: Invalid first line of request: "GET "
-- >>> parseRequestLine "GET /NotHTTP UNKNOWN/1.1"
-- *** Exception: Warp: Request line specified a non-HTTP request
-- >>> parseRequestLine "PRI * HTTP/2.0"
-- ("PRI","*","",HTTP/2.0)
parseRequestLine :: ByteString
                 -> IO (H.Method
                       ,ByteString -- Path
                       ,ByteString -- Query
                       ,H.HttpVersion)
parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> do
    -- "GET / HTTP/1.1" (14 bytes) is the shortest line that can be valid.
    when (len < 14) $ throwIO baderr
    let methodptr = ptr `plusPtr` off :: Ptr Word8
        -- One past the last byte of the request line.
        limptr = methodptr `plusPtr` len :: Ptr Word8
        lim0 = fromIntegral len

    -- The first space ends the method; " / HTTP/1.1" (11 bytes) must still
    -- fit after it.
    pathptr0 <- memchr methodptr 32 lim0 -- ' '
    when (pathptr0 == nullPtr || (limptr `minusPtr` pathptr0) < 11) $
        throwIO baderr
    let pathptr = pathptr0 `plusPtr` 1 :: Ptr Word8
        -- Bound the search by the bytes remaining *from pathptr*.  Measuring
        -- from pathptr0 would let memchr examine the byte at limptr, one past
        -- the end of the request line.
        lim1 = fromIntegral (limptr `minusPtr` pathptr)

    -- The next space ends the request target; " HTTP/1.1" (9 bytes) must
    -- still fit after it.
    httpptr0 <- memchr pathptr 32 lim1 -- ' '
    when (httpptr0 == nullPtr || (limptr `minusPtr` httpptr0) < 9) $
        throwIO baderr
    let httpptr = httpptr0 `plusPtr` 1 :: Ptr Word8
        -- Length of the request target (path plus optional query).
        lim2 = fromIntegral (httpptr0 `minusPtr` pathptr)

    checkHTTP httpptr
    !hv <- httpVersion httpptr
    -- The first '?' inside the target, if any, starts the query string.
    queryptr <- memchr pathptr 63 lim2 -- '?'

    let !method = bs ptr methodptr pathptr0
        !path
          | queryptr == nullPtr = bs ptr pathptr httpptr0
          | otherwise           = bs ptr pathptr queryptr
        !query
          | queryptr == nullPtr = S.empty
          | otherwise           = bs ptr queryptr httpptr0

    return (method,path,query,hv)
  where
    baderr :: InvalidRequest
    baderr = BadFirstLine $ C8.unpack requestLine

    -- Throw 'NonHttp' unless the byte at offset @n@ from @p@ equals @w@.
    check :: Ptr Word8 -> Int -> Word8 -> IO ()
    check p n w = do
        w0 <- peek $ p `plusPtr` n
        when (w0 /= w) $ throwIO NonHttp

    -- Verify the fixed characters of "HTTP/?.?".  The digits at offsets 5
    -- and 7 are not checked here; 'httpVersion' reads them.
    checkHTTP :: Ptr Word8 -> IO ()
    checkHTTP httpptr = do
        check httpptr 0 72 -- 'H'
        check httpptr 1 84 -- 'T'
        check httpptr 2 84 -- 'T'
        check httpptr 3 80 -- 'P'
        check httpptr 4 47 -- '/'
        check httpptr 6 46 -- '.'

    -- Map the major/minor digits to a version: "1.1" -> 1.1, "2.0" -> 2.0,
    -- anything else -> 1.0.
    httpVersion :: Ptr a -> IO H.HttpVersion
    httpVersion httpptr = do
        major <- peek (httpptr `plusPtr` 5) :: IO Word8
        minor <- peek (httpptr `plusPtr` 7) :: IO Word8
        let version
              | major == 49 = if minor == 49 then H.http11 else H.http10
              | major == 50 && minor == 48 = H.HttpVersion 2 0
              | otherwise   = H.http10
        return version

    -- The half-open range [p0, p1) as a ByteString sharing 'fptr'.
    -- @ptr@ must be the address 'withForeignPtr' yielded for 'fptr'.
    bs :: Ptr b -> Ptr a -> Ptr a -> ByteString
    bs ptr p0 p1 = PS fptr o l
      where
        o = p0 `minusPtr` ptr
        l = p1 `minusPtr` p0

----------------------------------------------------------------

-- | Split one header line at its first colon into a case-insensitive field
-- name and a value.
--
-- Spaces and tabs directly after the colon are stripped from the value.
-- Trailing whitespace is kept.  A line without a colon yields the whole
-- line as the name and an empty value.
--
-- >>> parseHeader "Content-Length:47"
-- ("Content-Length","47")
-- >>> parseHeader "Accept-Ranges: bytes"
-- ("Accept-Ranges","bytes")
-- >>> parseHeader "Host:  example.com:8080"
-- ("Host","example.com:8080")
-- >>> parseHeader "NoSemiColon"
-- ("NoSemiColon","")

parseHeader :: ByteString -> H.Header
parseHeader line = (CI.mk fieldName, fieldValue)
  where
    -- Everything before the first ':' (58) is the field name; the
    -- remainder still starts with that colon, if one was found.
    (fieldName, afterName) = S.break (== 58) line
    -- Skip the colon, then leading spaces and tabs.
    fieldValue = S.dropWhile isSpaceOrTab (S.drop 1 afterName)

    isSpaceOrTab :: Word8 -> Bool
    isSpaceOrTab c = c == 32 || c == 9 -- ' ' or '\t'