{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- Copyright (c) 2009, Diego Souza
-- All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--   * Redistributions of source code must retain the above copyright notice,
--     this list of conditions and the following disclaimer.
--   * Redistributions in binary form must reproduce the above copyright notice,
--     this list of conditions and the following disclaimer in the documentation
--     and/or other materials provided with the distribution.
--   * Neither the name of the <ORGANIZATION> nor the names of its contributors
--     may be used to endorse or promote products derived from this software
--     without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | The request currently is only able to represent an HTTP request.
module Network.OAuth.Http.Request 
       ( -- * Types
         Request(..)
       , FieldList()
       , Version(..)
       , Method(..)
       , FormDataPart(..)
       , Content(..)
         -- * conversion of [FormDataPart] to curl's [HttpPost]
       , convertMultipart
         -- * FieldList related functions
       , fromList
       , singleton
       , empty
       , toList
       , parseQString
       , find
       , findWithDefault
       , ifindWithDefault
       , change
       , insert
       , replace
       , replaces
       , union
       , unionAll
         -- * Request related functions
       , showURL
       , showQString
       , showProtocol
       , showAuthority
       , showPath
       , parseURL
       ) where

import Control.Monad.State
import qualified Network.Curl.Post as Curl (HttpPost(..), Content(..))
import qualified Network.Curl.Types as Curl (Long)
import Network.OAuth.Http.PercentEncoding
import Network.OAuth.Http.Util
import Data.List (intercalate,isPrefixOf)
import Data.Monoid
import Data.Char (toLower)
import qualified Data.ByteString.Lazy as B
import qualified Data.Binary as Bi
import Foreign.C.String (CString)

-- | All known HTTP methods
data Method =   GET
              | POST
              | PUT
              | DELETE
              | TRACE
              | CONNECT
              | HEAD
  deriving (Eq)

-- | Supported HTTP versions
data Version =   Http10
               | Http11
  deriving (Eq)

-- | Key-value list.
newtype FieldList = FieldList { unFieldList :: [(String,String)] }
  deriving (Eq,Ord)

data Request = ReqHttp { version          :: Version        -- ^ Protocol version
                       , ssl              :: Bool           -- ^ Wheter or not to use ssl
                       , host             :: String         -- ^ The hostname to connect to
                       , port             :: Int            -- ^ The port to connect to
                       , method           :: Method         -- ^ The HTTP method of the request.
                       , reqHeaders       :: FieldList      -- ^ Request headers
                       , pathComps        :: [String]       -- ^ The path split into components 
                       , qString          :: FieldList      -- ^ The query string, usually set for GET requests
                       , reqPayload       :: B.ByteString   -- ^ The message body (the first/only string part)
                       , multipartPayload :: [FormDataPart] -- ^ The message body (i.e., for multipart/form-data)
                       }
  deriving (Eq,Show)

-- one part of a multipart/form-data POST request
data FormDataPart =
   FormDataPart { postName     :: String
                , contentType  :: Maybe String
                , content      :: Content
                , extraHeaders :: [String]
                -- , extraEntries :: [FormDataPart]    -- commented out in Curl's Post.hs
                , showName     :: Maybe String
                }
   deriving (Eq, Show)

-- data for one part
-- as either a String or a FilePath (for Curl to open and include)
data Content = ContentFile FilePath
             | ContentBuffer CString Curl.Long
             | ContentString String
   deriving (Eq, Show)

-- convert one Part to Curl's HttpPost
convertMultipart :: [FormDataPart] -> [Curl.HttpPost]
convertMultipart parts =
   map convertPart parts
   where
      convertPart :: FormDataPart -> Curl.HttpPost
      convertPart part =
         Curl.HttpPost { Curl.postName     = postName     part
                       , Curl.contentType  = contentType  part
                       , Curl.content      = convertCont (content part)
                       , Curl.extraHeaders = extraHeaders part
                       , Curl.showName     = showName     part
                       }

      convertCont :: Content -> Curl.Content
      convertCont (ContentFile   z) = Curl.ContentFile   z
      convertCont (ContentString y) = Curl.ContentString y
      convertCont (ContentBuffer w x) = Curl.ContentBuffer w x

-- | Show the protocol in use (currently either https or http)
showProtocol :: Request -> String
showProtocol req 
  | ssl req   = "https"
  | otherwise = "http"

-- | Show the host+port path of the request. May return only the host when
--   (ssl=False && port==80) or (ssl=True && port==443).
showAuthority :: Request -> String
showAuthority req 
  | ssl req && (port req)==443      = host req
  | not (ssl req) && (port req)==80 = host req
  | otherwise                       = host req ++":"++ show (port req)

-- | Show the path component of the URL.
showPath :: Request -> String
showPath = intercalate "/" . map encode . pathComps

-- | Show the querty string of the URL.
showQString :: Request -> String
showQString = show . qString

-- | Show the URL.
showURL :: Request -> String
showURL =   concat 
          . zipWith ($) [showProtocol,const "://",showAuthority,showPath,showQString'] 
          . repeat
  where showQString' :: Request -> String
        showQString' req 
          | null (unFieldList (qString req)) = ""
          | otherwise                        = '?' : showQString req

-- | Parse a URL and creates an request type.
parseURL :: String -> Maybe Request
parseURL tape = evalState parser (tape,Just initial)
  where parser = do { _parseProtocol
                    ; _parseSymbol (':',True)
                    ; _parseSymbol ('/',True)
                    ; _parseSymbol ('/',True)
                    ; _parseHost
                    ; _parseSymbol (':',False)
                    ; _parsePort
                    ; _parseSymbol ('/',True)
                    ; _parsePath
                    ; _parseSymbol ('?',False)
                    ; _parseQString
                    ; fmap snd get
                    }
        initial = ReqHttp { version    = Http11
                          , ssl        = False
                          , method     = GET
                          , host       = "127.0.0.1"
                          , port       = 80
                          , reqHeaders = fromList []
                          , pathComps  = []
                          , qString    = fromList []
                          , reqPayload = B.empty
                          , multipartPayload = []
                          }

-- | Parse a query string.
parseQString :: String -> FieldList
parseQString tape = evalState parser (tape,Just initial)
  where parser = do { _parseQString
                    ; fmap (qstring . snd) get
                    }

        qstring Nothing  = fromList []
        qstring (Just r) = qString r

        initial = ReqHttp { version    = Http11
                          , ssl        = False
                          , method     = GET
                          , host       = "127.0.0.1"
                          , port       = 80
                          , reqHeaders = fromList []
                          , pathComps  = []
                          , qString    = fromList []
                          , reqPayload = B.empty
                          , multipartPayload = []
                          }

-- | Creates a FieldList type from a list.
fromList :: [(String,String)] -> FieldList
fromList = FieldList

-- | Transforms a fieldlist into a list type.
toList :: FieldList -> [(String,String)]
toList = unFieldList

-- | Creates a FieldList out from a single element.
singleton :: (String,String) -> FieldList
singleton = fromList . (:[])

-- | Returns an empty fieldlist.
empty :: FieldList
empty = fromList []

-- | Updates all occurrences of a given key with a new value. Does nothing if
--   the values does not exist.
change :: (String,String) -> FieldList -> FieldList
change kv (FieldList list) = FieldList (change' kv list)
  where change' (k,v) ((k0,v0):fs) 
          | k0==k     = (k0,v) : change' (k,v) fs
          | otherwise = (k0,v0) : change' (k,v) fs
        change' _ []  = []

-- | Inserts a new value into a fieldlist.
insert :: (String,String) -> FieldList -> FieldList
insert kv = mappend (FieldList [kv])

-- | Inserts or updates occurrences of a given key.
replace :: (String,String) -> FieldList -> FieldList
replace (k,v) fs 
  | null $ find (==k) fs = insert (k,v) fs
  | otherwise            = change (k,v) fs

-- | Same as /replace/ but work on a list type
replaces :: [(String,String)] -> FieldList -> FieldList
replaces fs field = foldr (replace) field fs

-- | Find keys that satisfy a given predicate.
find :: (String -> Bool) -> FieldList -> [String]
find p (FieldList list) = map snd (filter (p.fst) list)

-- | Combines two fieldsets, but prefere items of the first list.
union :: FieldList -> FieldList -> FieldList
union (FieldList as) bs = foldr replace bs as

-- | Combines two fieldsets keeping duplicates.
unionAll :: FieldList -> FieldList -> FieldList
unionAll (FieldList as) bs = foldr insert bs as

-- | Finds a the value defined in a fieldlist or returns a default value. In
-- the event there are multiple values under the same key the first one is
-- returned.
findWithDefault :: (String,String) -> FieldList -> String
findWithDefault (key,def) fields 
  | null values = def
  | otherwise   = head values
    where values = find (==key) fields

-- | Same as <findWithDefault> but the match is case-insenstiive.
ifindWithDefault :: (String,String) -> FieldList -> String
ifindWithDefault (key,def) fields 
  | null values = def
  | otherwise   = head values
    where values = find (\k -> lower k == lower key) fields
          lower  = map toLower

_parseProtocol :: State (String,Maybe Request) ()
_parseProtocol = do { (tape,req) <- get
                    ; if ("https" `isPrefixOf` tape)
                      then put (drop 5 tape,liftM (\r -> r {ssl=True,port=443}) req)
                      else if ("http" `isPrefixOf` tape) 
                           then put (drop 4 tape,liftM (\r -> r {ssl=False,port=80}) req)
                           else put ("",Nothing)
                    }

_parseHost :: State (String,Maybe Request) ()
_parseHost = do { (tape,req) <- get
                ; let (value,tape') = break (`elem` ":/") tape
                ; put (tape',liftM (\r -> r {host = value}) req)
                }

_parsePort :: State (String,Maybe Request) ()
_parsePort = do { (tape,req) <- get
                ; let (value,tape') = break (=='/') tape
                ; case (reads value)
                  of [(value',"")] -> put (tape',liftM (\r -> r {port = value'}) req)
                     _             -> put (tape',req)
                }

_parsePath :: State (String,Maybe Request) ()
_parsePath = do { (tape,req) <- get
                ; let (value,tape') = break (=='?') tape
                      value'        = "" : map (decodeWithDefault "") (splitBy (=='/') value)
                ; put (tape',liftM (\r -> r {pathComps=value'}) req)
                }

_parseQString :: State (String,Maybe Request) ()
_parseQString = do { (tape,req) <- get
                   ; let (value,tape') = break (=='#') tape
                         fields        = fromList $ filter (/=("","")) (map parseField (splitBy (=='&') value))
                   ; put (tape',liftM (\r -> r {qString=fields}) req)
                   }
  where parseField tape = let (k,v) = break (=='=') tape
                          in case (v)
                             of ('=':v') -> (decodeWithDefault "" k,decodeWithDefault "" v')
                                _        -> (decodeWithDefault "" k,"")

_parseSymbol :: (Char,Bool) -> State (String,Maybe Request) ()
_parseSymbol (c,required) = do { (tape,req) <- get
                               ; if ([c] `isPrefixOf` tape)
                                 then put (drop 1 tape,req)
                                 else if (required) 
                                      then put ("",Nothing)
                                      else put (tape,req)
                               }

instance Show Method where
  showsPrec _ m = case m 
                  of GET     -> showString "GET"
                     POST    -> showString "POST"
                     DELETE  -> showString "DELETE"
                     CONNECT -> showString "CONNECT"
                     HEAD    -> showString "HEAD"
                     TRACE   -> showString "TRACE"
                     PUT     -> showString "PUT"

instance Read Method where
  readsPrec _ "GET"     = [(GET,"")]
  readsPrec _ "POST"    = [(POST,"")]
  readsPrec _ "DELETE"  = [(DELETE,"")]
  readsPrec _ "CONNECT" = [(CONNECT,"")]
  readsPrec _ "HEAD"    = [(HEAD,"")]
  readsPrec _ "TRACE"   = [(TRACE,"")]
  readsPrec _ "PUT"     = [(PUT,"")]
  readsPrec _ _         = []

instance Read Version where
  readsPrec _ "HTTP/1.0" = [(Http10,"")]
  readsPrec _ "HTTP/1.1" = [(Http11,"")]
  readsPrec _ _          = []

instance Show Version where
  showsPrec _ v = case v 
                  of Http10 -> showString "HTTP/1.0"
                     Http11 -> showString "HTTP/1.1"

instance Show FieldList where
  showsPrec _ = showString . intercalate "&" . map showField . unFieldList
    where showField (k,v) = encode k ++"="++ encode v

instance Monoid FieldList where
  mempty  = FieldList []
  mappend (FieldList as) (FieldList bs) = FieldList (as `mappend` bs)

instance Bi.Binary FieldList where
  put = Bi.put . unFieldList
  get = fmap FieldList Bi.get