{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.Wai.Handler.Warp.Header where

import Data.Array
import Data.Array.ST
import qualified Data.ByteString as BS
import Data.CaseInsensitive (foldedCase)
import Network.HTTP.Types

import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | Array of HTTP header values keyed by a fixed 'Int' slot per header
--   (see 'RequestHeaderIndex' and 'ResponseHeaderIndex' for the layouts).
--   A slot holds 'Nothing' when the corresponding header was not present.
type IndexedHeader = Array Int (Maybe HeaderValue)

----------------------------------------------------------------

-- | Index the request headers that Warp looks up by position
--   (see 'RequestHeaderIndex').
--
--   Headers for which 'requestKeyIndex' returns @-1@ are skipped. If a
--   tracked header occurs more than once, the last occurrence is kept.
indexRequestHeader :: RequestHeaders -> IndexedHeader
indexRequestHeader hdr = traverseHeader hdr requestMaxIndex requestKeyIndex

-- | Slots of the request headers tracked in an 'IndexedHeader'.
--
--   The 'Enum' value of each constructor is its array index, so the
--   constructor order below defines the array layout.
data RequestHeaderIndex
    = ReqContentLength      -- ^ @Content-Length@
    | ReqTransferEncoding   -- ^ @Transfer-Encoding@
    | ReqExpect             -- ^ @Expect@
    | ReqConnection         -- ^ @Connection@
    | ReqRange              -- ^ @Range@
    | ReqHost               -- ^ @Host@
    | ReqIfModifiedSince    -- ^ @If-Modified-Since@
    | ReqIfUnmodifiedSince  -- ^ @If-Unmodified-Since@
    | ReqIfRange            -- ^ @If-Range@
    | ReqReferer            -- ^ @Referer@
    | ReqUserAgent          -- ^ @User-Agent@
    | ReqIfMatch            -- ^ @If-Match@
    | ReqIfNoneMatch        -- ^ @If-None-Match@
    deriving (Enum, Bounded)

-- | The largest index of an 'IndexedHeader' for an HTTP request.
--   Valid indices run from 0 to this value inclusive (so the array has
--   @requestMaxIndex + 1@ slots) and correspond, in order, to:
--
-- - \"Content-Length\"
-- - \"Transfer-Encoding\"
-- - \"Expect\"
-- - \"Connection\"
-- - \"Range\"
-- - \"Host\"
-- - \"If-Modified-Since\"
-- - \"If-Unmodified-Since\"
-- - \"If-Range\"
-- - \"Referer\"
-- - \"User-Agent\"
-- - \"If-Match\"
-- - \"If-None-Match\"
requestMaxIndex :: Int
requestMaxIndex = fromEnum (maxBound :: RequestHeaderIndex)

-- | Map a request header name to its 'RequestHeaderIndex' slot, or @-1@
--   if the header is not tracked.
--
--   The name is compared in its case-folded form ('foldedCase') against
--   lowercase literals. Dispatching on the length first means at most two
--   byte-string comparisons are done per lookup. When every guard of a
--   length alternative fails, matching falls through to the @_@ case.
requestKeyIndex :: HeaderName -> Int
requestKeyIndex hn = case BS.length bs of
    4 | bs == "host" -> fromEnum ReqHost
    5 | bs == "range" -> fromEnum ReqRange
    6 | bs == "expect" -> fromEnum ReqExpect
    7 | bs == "referer" -> fromEnum ReqReferer
    8
        | bs == "if-range" -> fromEnum ReqIfRange
        | bs == "if-match" -> fromEnum ReqIfMatch
    10
        | bs == "user-agent" -> fromEnum ReqUserAgent
        | bs == "connection" -> fromEnum ReqConnection
    13 | bs == "if-none-match" -> fromEnum ReqIfNoneMatch
    14 | bs == "content-length" -> fromEnum ReqContentLength
    17
        | bs == "transfer-encoding" -> fromEnum ReqTransferEncoding
        | bs == "if-modified-since" -> fromEnum ReqIfModifiedSince
    19 | bs == "if-unmodified-since" -> fromEnum ReqIfUnmodifiedSince
    _ -> -1
  where
    bs = foldedCase hn

-- | A request 'IndexedHeader' in which no tracked header is present:
--   every slot from 0 to 'requestMaxIndex' holds 'Nothing'.
defaultIndexRequestHeader :: IndexedHeader
defaultIndexRequestHeader =
    array (0, requestMaxIndex) [(i, Nothing) | i <- [0 .. requestMaxIndex]]

----------------------------------------------------------------

-- | Index the response headers that Warp looks up by position
--   (see 'ResponseHeaderIndex').
--
--   Headers for which 'responseKeyIndex' returns @-1@ are skipped. If a
--   tracked header occurs more than once, the last occurrence is kept.
indexResponseHeader :: ResponseHeaders -> IndexedHeader
indexResponseHeader hdr = traverseHeader hdr responseMaxIndex responseKeyIndex

-- | Slots of the response headers tracked in an 'IndexedHeader'.
--
--   The 'Enum' value of each constructor is its array index, so the
--   constructor order below defines the array layout.
data ResponseHeaderIndex
    = ResContentLength  -- ^ @Content-Length@
    | ResServer         -- ^ @Server@
    | ResDate           -- ^ @Date@
    | ResLastModified   -- ^ @Last-Modified@
    deriving (Enum, Bounded)

-- | The largest index of an 'IndexedHeader' for an HTTP response.
--   Valid indices run from 0 to this value inclusive (so the array has
--   @responseMaxIndex + 1@ slots) and correspond, in order, to:
--
-- - \"Content-Length\"
-- - \"Server\"
-- - \"Date\"
-- - \"Last-Modified\"
responseMaxIndex :: Int
responseMaxIndex = fromEnum (maxBound :: ResponseHeaderIndex)

-- | Map a response header name to its 'ResponseHeaderIndex' slot, or @-1@
--   if the header is not tracked.
--
--   The name is compared in its case-folded form ('foldedCase') against
--   lowercase literals, after dispatching on its length. A failed guard
--   falls through to the @_@ case.
responseKeyIndex :: HeaderName -> Int
responseKeyIndex hn = case BS.length bs of
    4 | bs == "date" -> fromEnum ResDate
    6 | bs == "server" -> fromEnum ResServer
    13 | bs == "last-modified" -> fromEnum ResLastModified
    14 | bs == "content-length" -> fromEnum ResContentLength
    _ -> -1
  where
    bs = foldedCase hn

----------------------------------------------------------------

-- | Build an 'IndexedHeader' with indices @0 .. maxidx@ from a header list.
--
--   Each header name is mapped to a slot by @getIndex@. Headers for which
--   @getIndex@ returns @-1@ are ignored. Headers are written in list order,
--   so when several headers map to the same slot, the last one wins.
--   Untouched slots stay 'Nothing'.
--
--   NOTE(review): @getIndex@ is assumed to return either @-1@ or an index
--   in @0 .. maxidx@. Any other value makes 'writeArray' fail at runtime.
traverseHeader :: [Header] -> Int -> (HeaderName -> Int) -> IndexedHeader
traverseHeader hdr maxidx getIndex = runSTArray $ do
    arr <- newArray (0, maxidx) Nothing
    mapM_ (insert arr) hdr
    return arr
  where
    -- Store the value in its slot, skipping headers that are not tracked.
    insert arr (key, val)
        | idx == -1 = return ()
        | otherwise = writeArray arr idx (Just val)
      where
        idx = getIndex key