{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.HTTP.Conduit.Response
    ( Response (..)
    , getRedirectedRequest
    , getResponse
    , lbsResponse
    ) where

import Control.Arrow (first)
import Control.Monad (liftM)

import Control.Exception (throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)

import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L

import qualified Data.CaseInsensitive as CI

import Data.Default (def)

import Data.Conduit
import Data.Conduit.Internal (ResumableSource (..), Pipe (..))
import qualified Data.Conduit.Zlib as CZ
import qualified Data.Conduit.List as CL

import qualified Network.HTTP.Types as W
import Network.URI (parseURIReference)

import Network.HTTP.Conduit.Types (Response (..), CookieJar)

import Network.HTTP.Conduit.Manager
import Network.HTTP.Conduit.Request
import Network.HTTP.Conduit.Util
import Network.HTTP.Conduit.Chunk
import Network.HTTP.Conduit.Parser (sinkHeaders)

import Data.Void (Void, absurd)

import System.Timeout.Lifted (timeout)
#if MIN_VERSION_conduit(1, 0, 0)
import Data.Conduit.Internal (ConduitM (..))
#endif

-- | Build the follow-up request for a redirect response.
--
-- Given the original request, the response headers, the cookie jar obtained
-- from the response and the numeric status code, this returns the request
-- that should be issued next. The result is 'Nothing' when
--
-- * the status code is not in the 3xx range,
--
-- * the headers contain no @location@ entry, or
--
-- * the @location@ value cannot be parsed by 'parseURIReference' or cannot
--   be applied to the original request by 'setUriRelative'.
--
-- For status codes 302 and 303 the new request uses the @GET@ method with an
-- empty body; for every other 3xx code the method and body are kept.
--
-- This library does not record the chain of URLs visited while following
-- redirects. A caller who needs that chain has to drive the redirect loop
-- themselves, for example:
--
-- > myHttp req man = do
-- >    (res, redirectRequests) <- (`runStateT` []) $
-- >         'httpRedirect'
-- >             9000
-- >             (\req' -> do
-- >                res <- http req'{redirectCount=0} man
-- >                modify (\rqs -> req' : rqs)
-- >                return (res, getRedirectedRequest req' (responseHeaders res) (responseCookieJar res) (W.statusCode (responseStatus res)))
-- >             )
-- >             'lift'
-- >             req
-- >    applyCheckStatus (checkStatus req) res
-- >    return redirectRequests
getRedirectedRequest :: Request m -> W.ResponseHeaders -> CookieJar -> Int -> Maybe (Request m)
getRedirectedRequest req hs cookie_jar code
    | code < 300 || code >= 400 = Nothing
    | otherwise =
        fmap adjust $
            lookup "location" hs
                >>= parseURIReference . S8.unpack
                >>= setUriRelative req
  where
    -- Strictly speaking only 303 should switch to GET. Nearly every client
    -- does the same for 302 though, so we follow the crowd.
    adjust target
        | code `elem` [302, 303] =
            target
                { method = "GET"
                , requestBody = RequestBodyBS ""
                , cookieJar = jar
                }
        | otherwise = target { cookieJar = jar }

    -- Replace the contents of the original request's cookie jar field with
    -- the supplied jar; a request that carried no jar still carries none.
    jar = fmap (const cookie_jar) (cookieJar req)

-- | Drain the streaming body of a 'Response' and return the same response
-- with the collected chunks as a lazy 'L.ByteString' body.
--
-- The resumable source is run to completion (and finalized) via '$$+-'.
lbsResponse :: Monad m
            => Response (ResumableSource m S8.ByteString)
            -> m (Response L.ByteString)
lbsResponse res = liftM withBody (responseBody res $$+- CL.consume)
  where
    -- Swap the streaming body for the fully collected one.
    withBody chunks = res { responseBody = L.fromChunks chunks }

-- | Wrap a sink-like 'Pipe' (one whose output type is 'Void') so that it may
-- only consume a limited number of input bytes.
--
-- The first argument is the remaining byte budget. Every chunk fed to the
-- wrapped pipe is subtracted from it, and every chunk the pipe hands back as
-- a leftover is credited again. If the pipe asks for more input once the
-- budget is used up (@<= 0@), 'OverlongHeaders' is thrown.
--
-- This works directly on the 'Pipe' constructors because it can\'t be a
-- Conduit: that would lose leftovers.
checkHeaderLength :: MonadResource m
                  => Int
                  -> Pipe S8.ByteString S8.ByteString Void u m r
                  -> Pipe S8.ByteString S8.ByteString Void u m r
checkHeaderLength remaining pipe =
    case pipe of
        NeedInput onChunk onClose
            | remaining <= 0 -> liftIO (throwIO OverlongHeaders)
            | otherwise ->
                NeedInput
                    (\chunk ->
                        checkHeaderLength
                            (remaining - S8.length chunk)
                            (onChunk chunk))
                    onClose
        PipeM action -> PipeM (liftM (checkHeaderLength remaining) action)
        Done{} -> pipe
        -- The output type is 'Void', so no output value can exist.
        HaveOutput _ _ impossible -> absurd impossible
        -- Bytes handed back were not really consumed: refund them.
        Leftover rest unused ->
            Leftover (checkHeaderLength (remaining + S.length unused) rest) unused

-- | Read the status line and headers of a response from the given source and
-- build a 'Response' whose body streams the remainder of that source.
--
-- Arguments, in order:
--
-- * the action used to release the connection. It is called with 'Reuse'
--   only when the server did not send @Connection: close@, the response is
--   not HTTP\/1.0 and the body was consumed; otherwise with 'DontReuse'.
--
-- * an optional timeout (passed to 'timeout', i.e. in microseconds) that
--   covers reading the status line and headers only. 'ResponseTimeout' is
--   thrown when it expires.
--
-- * the request this response belongs to (used for its method, its
--   @rawBody@ setting and to decide whether the body must be gunzipped).
--
-- * the raw byte source of the connection.
--
-- The header section is limited to a 4096 byte budget by
-- 'checkHeaderLength'; 'OverlongHeaders' is thrown when it is exceeded.
--
-- The values of the @Connection@ and @Transfer-Encoding@ headers are
-- compared case-insensitively, since both are case-insensitive tokens
-- (e.g. @Connection: Close@ and @Transfer-Encoding: Chunked@ are honoured).
getResponse :: (MonadResource m, MonadBaseControl IO m)
            => ConnRelease m
            -> Maybe Int
            -> Request m
            -> Source m S8.ByteString
            -> m (Response (ResumableSource m S8.ByteString))
getResponse connRelease timeout'' req@(Request {..}) src1 = do
    -- Wrap an action in the optional timeout, turning expiry into an
    -- exception instead of a 'Nothing'.
    let timeout' =
            case timeout'' of
                Nothing -> id
                Just useconds -> \ma -> do
                    x <- timeout useconds ma
                    case x of
                        Nothing -> liftIO $ throwIO ResponseTimeout
                        Just y -> return y
    -- Parse status line and headers; src2 resumes right after them.
    (src2, ((vbs, sc, sm), hs)) <- timeout' $ src1 $$+
#if MIN_VERSION_conduit(1, 0, 0)
        ConduitM (checkHeaderLength 4096 $ unConduitM sinkHeaders)
#else
        (checkHeaderLength 4096 sinkHeaders)
#endif
    let version = if vbs == "1.1" then W.http11 else W.http10
    let s = W.Status sc sm
    let hs' = map (first CI.mk) hs
    let mcl = lookup "content-length" hs' >>= readDec . S8.unpack

    -- should we put this connection back into the connection manager?
    -- The @close@ connection option is case-insensitive, so compare it
    -- through 'CI.mk' rather than byte-for-byte.
    let toPut = Just (CI.mk "close") /= fmap CI.mk (lookup "connection" hs')
             && vbs /= "1.0"
    let cleanup bodyConsumed = connRelease $ if toPut && bodyConsumed then Reuse else DontReuse

    -- RFC 2616 section 4.4_1 defines responses that must not include a body
    body <-
        if hasNoBody method sc || mcl == Just 0
            then do
                -- Nothing left to read: release the connection right away
                -- and hand back an empty resumable source.
                cleanup True
                (rsrc, ()) <- return () $$+ return ()
                return rsrc
            else do
                -- Transfer-coding names are case-insensitive as well.
                let isChunked =
                        any (\(k, v) -> k == "transfer-encoding"
                                     && CI.mk v == "chunked") hs'
                    -- Undo the transfer framing: chunked encoding wins over
                    -- Content-Length; with neither, read until the source ends.
                    src3 =
                        if isChunked
                            then fmapResume ($= chunkedConduit rawBody) src2
                            else
                                case mcl of
                                    Just len -> fmapResume ($= requireLength len) src2
                                    Nothing  -> src2
                    -- Undo the content encoding, if requested.
                    src4 =
                        if needsGunzip req hs'
                            then fmapResume ($= (if isChunked then ungzipChunked else CZ.ungzip)) src3
                            else src3
                return $ addCleanup' cleanup src4

    -- The last field is left at its default value.
    return $ Response s version hs' body def
  where
    -- When a body is both chunked and gzipped, we need to flush each chunk
    -- immediately to ensure streaming behavior.
    ungzipChunked =
        CL.concatMap (\x -> [Chunk x, Flush])
        =$= CZ.decompressFlush (CZ.WindowBits 31)
        =$= awaitForever unChunk
      where
        unChunk Flush = return ()
        unChunk (Chunk x) = yield x
    -- Transform the source inside a 'ResumableSource', keeping its finalizer.
    fmapResume f (ResumableSource src m) = ResumableSource (f src) m
    -- Attach the release action both to the source itself (via 'addCleanup')
    -- and to the finalizer, which reports the body as not consumed.
    addCleanup' f (ResumableSource src m) = ResumableSource (addCleanup f src) (m >> f False)

-- | Pass through exactly the given number of bytes and then stop.
--
-- Bytes beyond the requested length that arrive in the same chunk are
-- returned upstream as a leftover. If the upstream ends before enough bytes
-- were seen, 'ResponseBodyTooShort' is thrown, carrying the expected length
-- and the number of bytes received so far.
requireLength :: MonadIO m => Int -> Conduit S.ByteString m S.ByteString
requireLength total = remaining total
  where
    remaining 0 = return ()
    remaining need = do
        mchunk <- await
        case mchunk of
            Nothing ->
                liftIO . throwIO $
                    ResponseBodyTooShort
                        (fromIntegral total)
                        (fromIntegral (total - need))
            Just chunk
                -- The chunk completes the body exactly.
                | need == have -> yield chunk
                -- The chunk overshoots: keep the prefix, give back the rest.
                | need < have -> do
                    let (wanted, extra) = S.splitAt need chunk
                    leftover extra
                    yield wanted
                -- Still short: forward the chunk and keep counting.
                | otherwise -> yield chunk >> remaining (need - have)
              where
                have = S.length chunk