module Network.Wai.Handler.Warp.Request (
    recvRequest
  , headerLines
  , pauseTimeoutKey
  ) where
import qualified Control.Concurrent as Conc (yield)
import Control.Exception (throwIO)
import Data.Array ((!))
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.CaseInsensitive as CI
import qualified Data.IORef as I
import qualified Network.HTTP.Types as H
import Network.Socket (SockAddr)
import Network.Wai
import Network.Wai.Handler.Warp.Conduit
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.RequestHeader
import Network.Wai.Handler.Warp.Settings (Settings, settingsNoParsePath)
import qualified Network.Wai.Handler.Warp.Timeout as Timeout
import Network.Wai.Handler.Warp.Types
import Network.Wai.Internal
import Prelude hiding (lines)
import Control.Monad (when)
import qualified Data.Vault.Lazy as Vault
import System.IO.Unsafe (unsafePerformIO)
-- | Upper bound on the request head size checked by 'push'; exceeding it
-- raises 'OverLargeHeader'. 51200 bytes = 50 KiB.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 51200
-- | Read and parse one request head from @src@ and assemble the WAI
-- 'Request' for it.
--
-- The result tuple holds:
--
-- * the 'Request'; its 'requestBody' runs 'handleExpect' (for a possible
--   @100 Continue@) on its first read;
-- * the remaining-body-length counter from 'bodyAndSource' ('Nothing' for
--   chunked bodies);
-- * the indexed request headers;
-- * a second reader over the same underlying body that never sends
--   @100 Continue@ (presumably used by the caller to drain an unread
--   body — confirm against callers).
recvRequest :: Settings
            -> Connection
            -> InternalInfo
            -> SockAddr
            -> Source
            -> IO ( Request
                  , Maybe (I.IORef Int)
                  , IndexedHeader
                  , IO ByteString
                  )
recvRequest settings conn ii addr src = do
    (method, unparsedPath, path, query, httpversion, hdr) <-
        headerLines src >>= parseHeaderLines
    let indexed = indexRequestHeader hdr
        timeoutH = threadHandle ii
        sendContinue = handleExpect conn httpversion (indexed ! idxExpect)
    (rawBody, remainingRef, bodyLength) <-
        bodyAndSource src
                      (indexed ! idxContentLength)
                      (indexed ! idxTransferEncoding)
    -- Both readers wrap the same raw body; only the one handed to the
    -- application is given the 100-continue action.
    appBody <- timeoutBody remainingRef timeoutH rawBody sendContinue
    flushBody <- timeoutBody remainingRef timeoutH rawBody (return ())
    let rawPath
          | settingsNoParsePath settings = unparsedPath
          | otherwise                    = path
        req = Request
          { requestMethod      = method
          , httpVersion        = httpversion
          , pathInfo           = H.decodePathSegments path
          , rawPathInfo        = rawPath
          , rawQueryString     = query
          , queryString        = H.parseQuery query
          , requestHeaders     = hdr
          , isSecure           = False
          , remoteHost         = addr
          , requestBody        = appBody
          , vault              = Vault.insert pauseTimeoutKey
                                              (Timeout.pause timeoutH)
                                              Vault.empty
          , requestBodyLength  = bodyLength
          , requestHeaderHost  = indexed ! idxHost
          , requestHeaderRange = indexed ! idxRange
          }
    return (req, remainingRef, indexed, flushBody)
-- | Read the request head from @src@ and return its lines, as collected by
-- 'push'. Throws 'ConnectionClosedByPeer' when the very first read yields
-- no data.
headerLines :: Source -> IO [ByteString]
headerLines src = readSource src >>= begin
  where
    begin firstChunk
      | S.null firstChunk = throwIO ConnectionClosedByPeer
      | otherwise         = push src (THStatus 0 id id) firstChunk
-- | When the request carries @Expect: 100-continue@, send an interim
-- @100 Continue@ response on @conn@ and then yield the current thread;
-- for any other (or missing) expectation, do nothing.
--
-- The expectation token is compared case-insensitively, because the
-- Expect field value is case-insensitive (RFC 7231 §5.1.1).
handleExpect :: Connection
             -> H.HttpVersion
             -> Maybe HeaderValue
             -> IO ()
handleExpect conn ver (Just expectation)
    | CI.foldCase expectation == "100-continue" = do
        connSendAll conn continue
        Conc.yield
  where
    -- NOTE(review): RFC 7231 §5.1.1 says a 100-continue expectation in an
    -- HTTP/1.0 request MUST be ignored; this still answers it with an
    -- HTTP/1.0 status line. Kept as-is to avoid changing client-visible
    -- behaviour — confirm whether it should be dropped.
    continue
      | ver == H.http11 = "HTTP/1.1 100 Continue\r\n\r\n"
      | otherwise       = "HTTP/1.0 100 Continue\r\n\r\n"
handleExpect _    _   _ = return ()
-- | Choose how to read the request body from the Content-Length (@cl@) and
-- Transfer-Encoding (@te@) header values.
--
-- * Chunked transfer encoding: a chunked reader over @src@, no remaining
--   counter, and 'ChunkedBody'.
-- * Otherwise: a reader built by 'mkISource' for @toLength cl@ bytes, that
--   source's remaining-byte 'I.IORef', and 'KnownLength' of the same size.
bodyAndSource :: Source
              -> Maybe HeaderValue
              -> Maybe HeaderValue
              -> IO ( IO ByteString
                    , Maybe (I.IORef Int)
                    , RequestBodyLength
                    )
bodyAndSource src cl te =
    if isChunked te
        then chunkedBody
        else sizedBody (toLength cl)
  where
    chunkedBody = do
        csrc <- mkCSource src
        return (readCSource csrc, Nothing, ChunkedBody)
    sizedBody n = do
        isrc@(ISource _ remaining) <- mkISource src n
        return (readISource isrc, Just remaining, KnownLength (fromIntegral n))
-- | Interpret an optional Content-Length value with 'readInt'; an absent
-- header yields 0.
toLength :: Maybe HeaderValue -> Int
toLength = maybe 0 readInt
-- | True when a Transfer-Encoding value equals @chunked@, ignoring case;
-- an absent header is never chunked.
isChunked :: Maybe HeaderValue -> Bool
isChunked = maybe False ((== "chunked") . CI.foldCase)
-- | Wrap the body reader @rbody@ so that it drives the connection timeout.
--
-- * On the wrapper's first call, run @handle100Continue@, then resume the
--   timeout via @timeoutHandle@.
-- * After any read whose chunk is empty, or (when a remaining-byte counter
--   is supplied) after which that counter is @<= 0@, pause the timeout.
--
-- Each wrapper returned has its own private first-read flag.
timeoutBody :: Maybe (I.IORef Int)
            -> Timeout.Handle
            -> IO ByteString
            -> IO ()
            -> IO (IO ByteString)
timeoutBody remainingRef timeoutHandle rbody handle100Continue = do
    firstReadRef <- I.newIORef True
    let bodyDone chunk
          | S.null chunk = return True
          | otherwise    = maybe (return False) exhausted remainingRef
        exhausted ref = do
            left <- I.readIORef ref
            return $! left <= 0
    return $ do
        firstRead <- I.readIORef firstReadRef
        when firstRead $ do
            -- Order matters: the continue action runs before the timeout
            -- is resumed, and the flag is cleared last.
            handle100Continue
            Timeout.resume timeoutHandle
            I.writeIORef firstReadRef False
        chunk <- rbody
        done <- bodyDone chunk
        when done $ Timeout.pause timeoutHandle
        return chunk
-- | Transformation applied to the next input chunk; 'push' uses it to
-- prepend a carried-over partial line.
type BSEndo = S.ByteString -> S.ByteString

-- | Completed header lines, kept as a difference list.
type BSEndoList = [S.ByteString] -> [S.ByteString]

-- | State threaded through 'push' while reading the request head.
data THStatus = THStatus
    !Int        -- ^ running byte count, checked against 'maxTotalHeaderLength'
    BSEndoList  -- ^ header lines completed so far
    BSEndo      -- ^ prefix to apply to the next chunk
-- | Accumulate request-head lines until the empty line that ends the head.
--
-- @push src status chunk@ applies the carried-over prefix in @status@ to
-- @chunk@ and scans the result for LF-terminated lines. Each completed line,
-- minus its trailing CR\/LF, is appended to the difference list in
-- @status@. A line whose successor begins with SP or HTAB (an obsolete
-- folded header) is joined with that successor. Once the empty line is
-- reached, any bytes after it are returned to the source with
-- 'leftoverSource' and the collected lines are returned.
--
-- The byte counter in @status@ counts each consumed input byte exactly once.
-- Bytes carried over in the prefix were already counted when first read and
-- are not re-counted. 'OverLargeHeader' is thrown once the counter exceeds
-- 'maxTotalHeaderLength'. 'IncompleteHeaders' is thrown if the source ends
-- before the empty line.
--
-- NOTE(review): if a line's LF is the last byte of a chunk, a folded
-- continuation starting in the next chunk is not recognised as such.
push :: Source -> THStatus -> ByteString -> IO [ByteString]
push src (THStatus len lines prepend) bs'
    | len > maxTotalHeaderLength = throwIO OverLargeHeader
    | otherwise                  = push' mnl
  where
    -- Working buffer: carried-over bytes followed by the new chunk.
    bs = prepend bs'
    bsLen = S.length bs
    -- Bytes of 'bs' not yet counted in 'len' vs. those carried over (and
    -- therefore already counted on an earlier call).
    fresh = S.length bs'
    carried = bsLen - fresh

    -- Position of the first LF, and whether the following line is a
    -- folded continuation. An empty line (bare LF or CRLF) never folds,
    -- and folding is only detected when the next byte is already buffered.
    mnl :: Maybe (Int, Bool)
    mnl = do
        nl <- S.elemIndex 10 bs
        let blank = nl == 0 || (nl == 1 && S.index bs 0 == 13)
            next = S.index bs (nl + 1)
            folded = bsLen > nl + 1 && not blank && (next == 32 || next == 9)
        return (nl, folded)

    push' :: Maybe (Int, Bool) -> IO [ByteString]
    -- No LF yet: carry the whole buffer forward and read another chunk.
    push' Nothing = do
        bst <- readSource' src
        when (S.null bst) $ throwIO IncompleteHeaders
        push src status bst
      where
        status = THStatus (len + fresh) lines (S.append bs)

    -- Folded header: carry this line (without CR/LF) forward so it is
    -- joined with the continuation line that follows.
    push' (Just (end, True)) = push src status rest
      where
        rest = S.drop (end + 1) bs
        prepend' = S.append (SU.unsafeTake (checkCR bs end) bs)
        status = THStatus (len + (end + 1 - carried)) lines prepend'

    push' (Just (end, False))
      -- Empty line: end of the head. Whatever follows belongs to the body.
      | S.null line = do
            when (start < bsLen) $ leftoverSource src (SU.unsafeDrop start bs)
            return (lines [])
      -- More data is already buffered after this line.
      | start < bsLen = push src status (SU.unsafeDrop start bs)
      -- The buffer ended exactly at this LF: fetch the next chunk.
      | otherwise = do
            bst <- readSource' src
            when (S.null bst) $ throwIO IncompleteHeaders
            push src status bst
      where
        start = end + 1
        line = SU.unsafeTake (checkCR bs end) bs
        status = THStatus (len + (start - carried)) (lines . (line:)) id
-- | Given the index @pos@ of an LF byte in @bs@, return the length of the
-- line content before it. A CR immediately preceding the LF is excluded.
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && S.index bs p == 13 = p
    | otherwise                     = pos
  where
    -- Index of the byte just before the LF; only inspected when pos > 0.
    p = pos - 1
-- | 'Vault' key under which 'recvRequest' stores an action that pauses the
-- connection's timeout ('Timeout.pause').
--
-- The key is a process-wide global created with 'unsafePerformIO'. The
-- NOINLINE pragma is required so GHC cannot inline the definition and mint
-- a different key at each use site, which would make lookups miss.
pauseTimeoutKey :: Vault.Key (IO ())
pauseTimeoutKey = unsafePerformIO Vault.newKey
{-# NOINLINE pauseTimeoutKey #-}