module Network.HTTP.Client.Request
    ( parseUrl
    , setUriRelative
    , getUri
    , setUri
    , browserDecompress
    , alwaysDecompress
    , addProxy
    , applyBasicAuth
    , applyBasicProxyAuth
    , urlEncodedBody
    , needsGunzip
    , requestBuilder
    , useDefaultTimeout
    , setQueryString
    , streamFile
    , observedStreamFile
    ) where
import Data.Int (Int64)
import Data.Maybe (fromMaybe, isJust)
import Data.Monoid (mempty, mappend)
import Data.String (IsString(..))
import Data.Char (toLower)
import Control.Applicative ((<$>))
import Control.Monad (when, unless)
import Numeric (showHex)
import Data.Default.Class (Default (def))
import Blaze.ByteString.Builder (Builder, fromByteString, fromLazyByteString, toByteStringIO, flush)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import qualified Network.HTTP.Types as W
import Network.URI (URI (..), URIAuth (..), parseURI, relativeTo, escapeURIString, isAllowedInURI)
import Control.Monad.IO.Class (liftIO)
import Control.Exception (Exception, toException, throw, throwIO, IOException)
import qualified Control.Exception as E
import qualified Data.CaseInsensitive as CI
import qualified Data.ByteString.Base64 as B64
import Network.HTTP.Client.Types
import Network.HTTP.Client.Util
import Network.HTTP.Client.Connection
import Network.HTTP.Client.Util (readDec, (<>))
import Data.Time.Clock
import Control.Monad.Catch (MonadThrow, throwM)
import Data.IORef
import System.IO (withBinaryFile, hTell, hFileSize, Handle, IOMode (ReadMode))
-- | Build a 'Request' from a URL string.
--
-- The string is first escaped with 'escapeURIString' (leaving every character
-- accepted by 'isAllowedInURI' as is), then parsed with 'parseURI'; the
-- resulting 'URI' is applied to the default request ('def') via 'setUri'.
--
-- Throws 'InvalidUrlException' through 'throwM' when the escaped string does
-- not parse as a URI.  Failures raised by 'setUri' propagate unchanged.
parseUrl :: MonadThrow m => String -> m Request
parseUrl s =
    maybe invalid (setUri def) (parseURI escaped)
  where
    escaped = escapeURIString isAllowedInURI s
    invalid = throwM $ InvalidUrlException s "Invalid URL"
-- | Resolve @uri@ against the request's current location (as reported by
-- 'getUri') using 'relativeTo', then apply the result with 'setUri'.
--
-- With @network >= 2.4.0@ 'relativeTo' yields a 'URI' directly.  With older
-- versions it yields a 'Maybe', and a 'Nothing' is reported by throwing
-- 'InvalidUrlException'.
setUriRelative :: MonadThrow m => Request -> URI -> m Request
setUriRelative req uri =
#if MIN_VERSION_network(2,4,0)
    setUri req resolved
#else
    maybe (throwM $ InvalidUrlException (show uri) "Invalid URL") (setUri req) resolved
#endif
  where
    resolved = uri `relativeTo` getUri req
-- | Reconstruct the 'URI' a 'Request' points at.
--
-- The scheme is taken from 'secure', the authority from 'host' and 'port'
-- (the port is always rendered explicitly, user info is always empty), and
-- the path and query from 'path' and 'queryString'.  The fragment is empty.
getUri :: Request -> URI
getUri req = URI
    { uriScheme = scheme
    , uriAuthority = Just authority
    , uriPath = S8.unpack (path req)
    , uriQuery = S8.unpack (queryString req)
    , uriFragment = ""
    }
  where
    scheme
        | secure req = "https:"
        | otherwise  = "http:"
    authority = URIAuth
        { uriUserInfo = ""
        , uriRegName = S8.unpack (host req)
        , uriPort = ':' : show (port req)
        }
-- | Overwrite the location-related fields of a 'Request' from a 'URI'.
--
-- Sets 'host', 'port', 'secure', 'path' and 'queryString'; every other field
-- of the request is left untouched.  An empty URI path becomes @\"/\"@, and a
-- missing port defaults to 80 for @http:@ and 443 for @https:@.
--
-- Throws 'InvalidUrlException' (via 'throwM') when:
--
-- * the scheme, compared case-insensitively, is neither @http:@ nor @https:@,
-- * the URI has no authority component,
-- * the authority carries user info, or
-- * a port is present but 'readDec' cannot parse it.
setUri :: MonadThrow m => Request -> URI -> m Request
setUri req uri = do
    isSecure <- schemeSecurity (map toLower (uriScheme uri))
    authority <- maybe (failUri "URL must be absolute") return (uriAuthority uri)
    unless (null (uriUserInfo authority)) $
        failUri "URL auth not supported; use applyBasicAuth instead"
    portNumber <- portFor isSecure (uriPort authority)
    return req
        { host = S8.pack (uriRegName authority)
        , port = portNumber
        , secure = isSecure
        , path = S8.pack (nonEmptyPath (uriPath uri))
        , queryString = S8.pack (uriQuery uri)
        }
  where
    -- Every failure is reported against the textual form of the input URI.
    failUri :: MonadThrow m => String -> m a
    failUri = throwM . InvalidUrlException (show uri)
    -- Maps an already lower-cased scheme to the value of the 'secure' flag.
    schemeSecurity "http:"  = return False
    schemeSecurity "https:" = return True
    schemeSecurity _        = failUri "Invalid scheme"
    -- An explicit ":<digits>" port wins; otherwise use the scheme default.
    portFor _ (':' : digits) =
        maybe (failUri "Invalid port") return (readDec digits)
    portFor True  _ = return 443
    portFor False _ = return 80
    nonEmptyPath "" = "/"
    nonEmptyPath p  = p
-- | Multi-line, human-readable rendering of a 'Request'.
--
-- Only the fields listed below are printed; the remaining fields of the
-- record (for example the request body) do not appear in the output.
instance Show Request where
    show x = unlines (["Request {"] ++ map render fields ++ ["}"])
      where
        render (label, value) = label ++ value
        fields =
            [ ("  host                 = ", show (host x))
            , ("  port                 = ", show (port x))
            , ("  secure               = ", show (secure x))
            , ("  requestHeaders       = ", show (requestHeaders x))
            , ("  path                 = ", show (path x))
            , ("  queryString          = ", show (queryString x))
            , ("  method               = ", show (method x))
            , ("  proxy                = ", show (proxy x))
            , ("  rawBody              = ", show (rawBody x))
            , ("  redirectCount        = ", show (redirectCount x))
            , ("  responseTimeout      = ", show (responseTimeout x))
            , ("  requestVersion       = ", show (requestVersion x))
            ]
-- | Sentinel stored in a request's 'responseTimeout' to mean \"no explicit
-- timeout was chosen\"; the 'Default' instance of 'Request' in this module
-- installs it.
--
-- The value is deliberately negative.  A positive number is a legitimate
-- timeout that a caller could set, so it would be indistinguishable from the
-- sentinel.  A negative value also stays harmless if it is ever handed to a
-- @timeout@ function unreplaced.
-- NOTE(review): the code that recognises this sentinel lives outside this
-- module -- confirm it compares against 'useDefaultTimeout' by name rather
-- than against a literal.
useDefaultTimeout :: Maybe Int
useDefaultTimeout = Just (-3425)
-- | Default request: a plain-HTTP @GET@ of @\/@ on @localhost:80@ with an
-- empty body, no extra headers, no proxy, up to 10 redirects, HTTP\/1.1 and
-- an empty default cookie jar.
instance Default Request where
    def = Request
        { host = "localhost"
        , port = 80
        , secure = False
        , requestHeaders = []
        , path = "/"
        , queryString = S8.empty
        , requestBody = RequestBodyLBS L.empty
        , method = "GET"
        , proxy = Nothing
        , hostAddress = Nothing
        , rawBody = False
        , decompress = browserDecompress
        , redirectCount = 10
        -- Any status outside 200..299 is turned into a 'StatusCodeException'.
        , checkStatus = \s@(W.Status sci _) hs cookie_jar ->
            if 200 <= sci && sci < 300
                then Nothing
                else Just $ toException $ StatusCodeException s hs cookie_jar
        , responseTimeout = useDefaultTimeout
        -- Runs the connection-acquiring action @f@ under an optional timeout
        -- (microseconds) and returns the result together with the time budget
        -- that is still left.  @exc@ is thrown when the budget is exhausted,
        -- either because 'timeout' fired or because nothing remains afterwards.
        , getConnectionWrapper = \mtimeout exc f ->
            case mtimeout of
                -- No timeout requested: run as is, no remaining budget to report.
                Nothing -> fmap ((,) Nothing) f
                Just timeout' -> do
                    before <- getCurrentTime
                    mres <- timeout timeout' f
                    case mres of
                        Nothing -> throwIO exc
                        Just res -> do
                            now <- getCurrentTime
                            -- 'diffUTCTime' is in seconds; scale to microseconds
                            -- so it can be subtracted from the original budget.
                            let timeSpentMicro = diffUTCTime now before * 1000000
                                remainingTime = round $ fromIntegral timeout' - timeSpentMicro
                            if remainingTime <= 0
                                then throwIO exc
                                else return (Just remainingTime, res)
        , cookieJar = Just def
        , requestVersion = W.http11
        -- 'IOException's raised while sending the request body are ignored;
        -- every other exception is rethrown.
        , onRequestBodyException = \se ->
            case E.fromException se of
                Just (_ :: IOException) -> return ()
                Nothing -> throwIO se
        }
-- | Allows a string literal to be used as a 'Request'.
--
-- The string is run through 'parseUrl'; a parse failure is rethrown with
-- 'throw', so it surfaces as an exception when the request is evaluated.
instance IsString Request where
    fromString = either throw id . parseUrl
-- | Decompression predicate that answers 'True' for every content type.
alwaysDecompress :: S.ByteString -> Bool
alwaysDecompress _ = True
-- | Decompression predicate that answers 'True' for every content type
-- except exactly @application\/x-tar@.
browserDecompress :: S.ByteString -> Bool
browserDecompress contentType = contentType /= "application/x-tar"
-- | Add an HTTP Basic @Authorization@ header built from a user name and a
-- password.
--
-- The header is prepended to 'requestHeaders'; any @Authorization@ header
-- already present is left in place.
applyBasicAuth :: S.ByteString -> S.ByteString -> Request -> Request
applyBasicAuth user passwd req =
    req { requestHeaders = (CI.mk "Authorization", credentials) : requestHeaders req }
  where
    -- "Basic " followed by base64("user:passwd").
    userPass = S8.concat [ user, ":", passwd ]
    credentials = S8.concat [ "Basic ", B64.encode userPass ]
-- | Route the request through a proxy, given as a host name and a port.
-- Replaces any proxy previously set on the request.
addProxy :: S.ByteString -> Int -> Request -> Request
addProxy hst prt req = req { proxy = Just proxySettings }
  where
    proxySettings = Proxy hst prt
-- | Add an HTTP Basic @Proxy-Authorization@ header built from a user name
-- and a password.
--
-- The header is prepended to 'requestHeaders'; any @Proxy-Authorization@
-- header already present is left in place.
applyBasicProxyAuth :: S.ByteString -> S.ByteString -> Request -> Request
applyBasicProxyAuth user passwd req =
    req { requestHeaders = (CI.mk "Proxy-Authorization", credentials) : requestHeaders req }
  where
    -- "Basic " followed by base64("user:passwd").
    userPass = S8.concat [ user, ":", passwd ]
    credentials = S8.concat [ "Basic ", B64.encode userPass ]
-- | Turn the request into a URL-encoded form submission.
--
-- The key\/value pairs are rendered with 'W.renderSimpleQuery' (without a
-- leading question mark) and become the request body.  The method is set to
-- @POST@, and any existing @Content-Type@ header is replaced by
-- @application\/x-www-form-urlencoded@.
urlEncodedBody :: [(S.ByteString, S.ByteString)] -> Request -> Request
urlEncodedBody headers req = req
    { requestBody = RequestBodyLBS encoded
    , method = "POST"
    , requestHeaders = contentTypeHeader : otherHeaders
    }
  where
    contentTypeName = "Content-Type"
    contentTypeHeader = (contentTypeName, "application/x-www-form-urlencoded")
    otherHeaders = filter ((/= contentTypeName) . fst) (requestHeaders req)
    encoded = L.fromChunks [W.renderSimpleQuery False headers]
-- | Decide whether a response body should be gunzipped.
--
-- This is the case when the request does not ask for the raw body, the
-- response carries a @content-encoding: gzip@ header, and the request's
-- 'decompress' predicate accepts the response's @content-type@ (the empty
-- string when that header is missing).
needsGunzip :: Request
            -> [W.Header] -- ^ response headers
            -> Bool
needsGunzip req hs'
    | rawBody req = False
    | ("content-encoding", "gzip") `notElem` hs' = False
    | otherwise = decompress req contentType
  where
    contentType = fromMaybe "" (lookup "content-type" hs')
-- | Serialise a 'Request' onto a 'Connection'.
--
-- When the request carries an @Expect: 100-continue@ header, only the request
-- line and headers are written (and flushed) now; the result is then
-- @Just action@, where @action@ sends the body later.  Otherwise headers and
-- body are both sent immediately and the result is 'Nothing'.
--
-- Exceptions raised while sending the body are passed to the request's
-- 'onRequestBodyException' handler.
requestBuilder :: Request -> Connection -> IO (Maybe (IO ()))
requestBuilder req Connection {..}
    | expectContinue = flushHeaders >> return (Just (checkBadSend sendLater))
    | otherwise      = sendNow      >> return Nothing
  where
    -- True only when the Expect header's value is exactly "100-continue".
    expectContinue   = Just "100-continue" == lookup "Expect" (requestHeaders req)
    -- Runs an action, routing its exceptions to the request's handler.
    checkBadSend f   = f `E.catch` onRequestBodyException req
    -- Writes a Builder to the connection through 'connectionWrite'.
    writeBuilder     = toByteStringIO connectionWrite
    -- Writes request line + headers followed by the given Builder.
    writeHeadersWith = writeBuilder . (builder `mappend`)
    flushHeaders     = writeHeadersWith flush
    -- For each body type: the length to advertise ('Nothing' selects chunked
    -- transfer encoding), an action sending headers and body together, and an
    -- action sending the body alone.
    (contentLength, sendNow, sendLater) =
        case requestBody req of
            RequestBodyLBS lbs ->
                let body  = fromLazyByteString lbs
                    now   = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just (L.length lbs), now, later)
            RequestBodyBS bs ->
                let body  = fromByteString bs
                    now   = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just (fromIntegral $ S.length bs), now, later)
            RequestBodyBuilder len body ->
                let now   = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just len, now, later)
            -- Streaming bodies cannot be merged into the header Builder, so
            -- the headers are flushed on their own before the stream runs.
            RequestBodyStream len stream ->
                let body = writeStream False stream
                    -- Unlike the cases above, the header write here sits
                    -- outside 'checkBadSend': only failures while streaming
                    -- the body reach the request's exception handler.
                    now  = flushHeaders >> checkBadSend body
                in (Just len, now, body)
            RequestBodyStreamChunked stream ->
                let body = writeStream True stream
                    now  = flushHeaders >> checkBadSend body
                in (Nothing, now, body)
    -- Pulls chunks from the stream until it yields an empty chunk.  In chunked
    -- mode every chunk is framed as "<hex length>\r\n<data>\r\n", and the
    -- empty chunk is written as the terminator "0\r\n\r\n".
    writeStream isChunked withStream =
        withStream loop
      where
        loop stream = do
            bs <- stream
            when isChunked $
                connectionWrite $ S8.pack $ showHex (S.length bs) (if S.null bs then "\r\n\r\n" else "\r\n")
            unless (S.null bs) $ do
                connectionWrite bs
                when isChunked $ connectionWrite "\r\n"
                loop stream
    -- Host header value: the port is omitted when it is the default for the
    -- scheme (80 for http, 443 for https).
    hh
        | port req == 80 && not (secure req) = host req
        | port req == 443 && secure req = host req
        | otherwise = host req <> S8.pack (':' : show (port req))
    requestProtocol
        | secure req = fromByteString "https://"
        | otherwise  = fromByteString "http://"
    -- The request target is made absolute (scheme + host) only for
    -- non-secure requests that go through a proxy.
    requestHostname
        | isJust (proxy req) && not (secure req)
            = requestProtocol <> fromByteString hh
        | otherwise          = mempty
    -- A known length yields Content-Length, except for GET/HEAD with a
    -- zero-length body where no header is added; an unknown length yields
    -- Transfer-Encoding: chunked.
    contentLengthHeader (Just contentLength') =
            if method req `elem` ["GET", "HEAD"] && contentLength' == 0
                then id
                else (:) ("Content-Length", S8.pack $ show contentLength')
    contentLengthHeader Nothing = (:) ("Transfer-Encoding", "chunked")
    -- No Accept-Encoding header: add "gzip".  Header present with an empty
    -- value: strip it.  Anything else: leave the headers unchanged.
    acceptEncodingHeader =
        case lookup "Accept-Encoding" $ requestHeaders req of
            Nothing -> (("Accept-Encoding", "gzip"):)
            Just "" -> filter (\(k, _) -> k /= "Accept-Encoding")
            Just _ -> id
    -- Adds a Host header unless the given header list already has one.
    hostHeader x =
        case lookup "Host" x of
            Nothing -> ("Host", hh) : x
            Just{} -> x
    -- The request's own headers with the derived headers applied on top.
    headerPairs :: W.RequestHeaders
    headerPairs = hostHeader
                $ acceptEncodingHeader
                $ contentLengthHeader contentLength
                $ requestHeaders req
    -- Request line, then every header, then the blank line ending the head.
    -- A missing leading '/' on the path and a missing leading '?' on a
    -- non-empty query string are filled in.
    builder :: Builder
    builder =
            fromByteString (method req)
            <> fromByteString " "
            <> requestHostname
            <> (case S8.uncons $ path req of
                    Just ('/', _) -> fromByteString $ path req
                    _ -> fromChar '/' <> fromByteString (path req))
            <> (case S8.uncons $ queryString req of
                    Nothing -> mempty
                    Just ('?', _) -> fromByteString $ queryString req
                    _ -> fromChar '?' <> fromByteString (queryString req))
            <> (case requestVersion req of
                    W.HttpVersion 1 1 -> fromByteString " HTTP/1.1\r\n"
                    W.HttpVersion 1 0 -> fromByteString " HTTP/1.0\r\n"
                    version ->
                        fromChar ' ' <>
                        fromShow version <>
                        fromByteString "\r\n")
            <> foldr
                (\a b -> headerPairToBuilder a <> b)
                (fromByteString "\r\n")
                headerPairs
    -- Renders one header as "Name: value\r\n", using the name's original case.
    headerPairToBuilder (k, v) =
           fromByteString (CI.original k)
        <> fromByteString ": "
        <> fromByteString v
        <> fromByteString "\r\n"
-- | Replace the request's query string with one rendered from the given
-- parameters by 'W.renderQuery' (with the leading question mark included).
setQueryString :: [(S.ByteString, Maybe S.ByteString)] -> Request -> Request
setQueryString qs req = req { queryString = rendered }
  where
    rendered = W.renderQuery True qs
-- | Build a streaming 'RequestBody' from a file.  This is
-- 'observedStreamFile' with an observer that ignores every status update.
streamFile :: FilePath -> IO RequestBody
streamFile fp = observedStreamFile ignoreStatus fp
  where
    ignoreStatus _ = return ()
-- | Build a streaming 'RequestBody' from a file, reporting progress to an
-- observer.
--
-- The file is opened once up front to read its size, which becomes the
-- advertised body length.  It is opened again each time the body is
-- streamed.  After every chunk read (of at most 'defaultChunkSize' bytes)
-- the observer receives a 'StreamFileStatus' holding the total size, the
-- handle position after the read, and the size of the chunk just read.
observedStreamFile :: (StreamFileStatus -> IO ()) -> FilePath -> IO RequestBody
observedStreamFile obs fp = do
    totalBytes <- fmap fromIntegral (withBinaryFile fp ReadMode hFileSize)
    let popFrom :: Handle -> Popper
        popFrom handle = do
            chunk <- S.hGetSome handle defaultChunkSize
            offset <- fmap fromIntegral (hTell handle)
            obs StreamFileStatus
                { fileSize = totalBytes
                , readSoFar = offset
                , thisChunkSize = S.length chunk
                }
            return chunk
        withOpenFile :: GivesPopper ()
        withOpenFile consume =
            withBinaryFile fp ReadMode (\handle -> consume (popFrom handle))
    return (RequestBodyStream totalBytes withOpenFile)