module Network.OAuth.Http.Request
(
Request(..)
, FieldList()
, Version(..)
, Method(..)
, fromList
, singleton
, empty
, toList
, parseQString
, find
, findWithDefault
, ifindWithDefault
, change
, insert
, replace
, replaces
, union
, unionAll
, showURL
, showQString
, showProtocol
, showAuthority
, showPath
, parseURL
) where
import Control.Monad.State
import qualified Data.Binary as Bi
import qualified Data.ByteString.Lazy as B
import Data.Char (isDigit, isSpace, toLower)
import Data.List (intercalate,isPrefixOf)
import Data.Monoid

import Network.OAuth.Http.PercentEncoding
import Network.OAuth.Http.Util
-- | HTTP request methods supported by this module.
--
-- Conversions to and from text are provided by the hand-written 'Show' and
-- 'Read' instances in this module (the upper-case method names).
data Method = GET | POST | PUT | DELETE | TRACE | CONNECT | HEAD
  deriving (Eq)
-- | HTTP protocol version of a request (HTTP/1.0 or HTTP/1.1).
data Version = Http10 | Http11
  deriving (Eq)
-- | An ordered list of name/value pairs, used for both headers and
-- query-string fields.
--
-- Order is preserved and duplicate names are allowed; for example, 'insert'
-- prepends without checking for an existing key.
newtype FieldList = FieldList
  { unFieldList :: [(String, String)] -- ^ The underlying pairs, in order.
  } deriving (Eq, Ord)
-- | An HTTP request description, as produced by 'parseURL' and rendered by
-- 'showURL'.
data Request = ReqHttp
  { version    :: Version
    -- ^ HTTP protocol version.
  , ssl        :: Bool
    -- ^ 'True' for the https scheme, 'False' for http.
  , host       :: String
    -- ^ Host name, as written in the URL.
  , port       :: Int
    -- ^ TCP port; 'showAuthority' omits it when it is the scheme default.
  , method     :: Method
    -- ^ Request method.
  , reqHeaders :: FieldList
    -- ^ Request headers as ordered name/value pairs.
  , pathComps  :: [String]
    -- ^ Path segments; 'parseURL' puts an empty segment first so that
    -- 'showPath' renders a leading '/'.
  , qString    :: FieldList
    -- ^ Query-string fields.
  , reqPayload :: B.ByteString
    -- ^ Request body.
  } deriving (Eq, Show)
-- | URL scheme of a request: @"https"@ when 'ssl' is set, @"http"@ otherwise.
showProtocol :: Request -> String
showProtocol req = if ssl req then "https" else "http"
-- | Authority part of the URL: the host, plus @:port@ unless the port is the
-- default for the scheme (443 for https, 80 for http).
showAuthority :: Request -> String
showAuthority req =
  case (ssl req, port req) of
    (True, 443)  -> host req
    (False, 80)  -> host req
    (_, portNum) -> host req ++ ":" ++ show portNum
-- | Path part of the URL: each segment in 'pathComps' is percent-encoded with
-- 'encode' and the segments are joined with '/'.
--
-- A leading empty segment (as produced by 'parseURL') yields a leading '/'.
showPath :: Request -> String
showPath req = intercalate "/" [encode segment | segment <- pathComps req]
-- | Query string of a request, without the leading '?', rendered through the
-- 'Show' instance of 'FieldList'.
showQString :: Request -> String
showQString req = shows (qString req) ""
-- | Render a request as a URL: scheme, @"://"@, authority and path. When the
-- request has query fields, it also appends '?' followed by the query string.
showURL :: Request -> String
showURL req =
  concat [showProtocol req, "://", showAuthority req, showPath req, query]
  where
    -- An empty query produces no '?' at all.
    query
      | null (unFieldList (qString req)) = ""
      | otherwise                        = '?' : showQString req
-- | Parse an absolute http or https URL into a 'Request'.
--
-- The scheme sets 'ssl' and the default 'port' (80 or 443). An explicit
-- numeric port after ':' overrides the default. The result is 'Nothing' when
-- the scheme is neither http nor https, or when the "://" separator is
-- missing. Fields the URL does not mention keep the values from @initial@
-- below.
parseURL :: String -> Maybe Request
parseURL tape = evalState parser (tape, Just initial)
  where
    parser = do
      _parseProtocol
      _parseSymbol (':', True)
      _parseSymbol ('/', True)
      _parseSymbol ('/', True)
      _parseHost
      _parseSymbol (':', False)
      _parsePort
      -- The '/' that starts the path is optional. An authority-only URL such
      -- as "http://example.com" has an empty path (RFC 3986, section 3.3)
      -- and is parsed exactly like "http://example.com/". Requiring the
      -- slash made such URLs fail with Nothing.
      _parseSymbol ('/', False)
      _parsePath
      _parseSymbol ('?', False)
      _parseQString
      fmap snd get
    initial = ReqHttp { version    = Http11
                      , ssl        = False
                      , method     = GET
                      , host       = "127.0.0.1"
                      , port       = 80
                      , reqHeaders = fromList []
                      , pathComps  = []
                      , qString    = fromList []
                      , reqPayload = B.empty
                      }
-- | Parse a raw query string into a 'FieldList'.
--
-- Fields are separated by '&', each key and value is decoded with
-- 'decodeWithDefault', and parsing stops at the first '#'. Pass the text
-- after '?': a leading '?' would otherwise become part of the first key.
parseQString :: String -> FieldList
parseQString tape =
  evalState (_parseQString >> gets (fieldsOf . snd)) (tape, Just blank)
  where
    fieldsOf = maybe (fromList []) qString
    blank = ReqHttp { version    = Http11
                    , ssl        = False
                    , method     = GET
                    , host       = "127.0.0.1"
                    , port       = 80
                    , reqHeaders = fromList []
                    , pathComps  = []
                    , qString    = fromList []
                    , reqPayload = B.empty
                    }
-- | Build a 'FieldList' from name/value pairs, keeping their order.
fromList :: [(String,String)] -> FieldList
fromList pairs = FieldList pairs
-- | The name/value pairs of a 'FieldList', in order.
toList :: FieldList -> [(String,String)]
toList (FieldList pairs) = pairs
-- | A 'FieldList' holding exactly one pair.
singleton :: (String,String) -> FieldList
singleton pair = fromList [pair]
-- | The 'FieldList' with no pairs.
empty :: FieldList
empty = FieldList []
-- | Overwrite the value of every pair whose name is @k@ with @v@.
--
-- Order is preserved. If no pair has that name, the list is unchanged; the
-- pair is not added (use 'replace' for that).
change :: (String,String) -> FieldList -> FieldList
change (k, v) (FieldList pairs) = FieldList (map overwrite pairs)
  where
    overwrite (name, old)
      | name == k = (name, v)
      | otherwise = (name, old)
-- | Prepend a pair, without checking whether its name is already present.
insert :: (String,String) -> FieldList -> FieldList
insert pair (FieldList pairs) = FieldList (pair : pairs)
-- | Set name @k@ to value @v@. Every existing occurrence is overwritten
-- (see 'change'); if the name is absent, the pair is prepended (see 'insert').
replace :: (String,String) -> FieldList -> FieldList
replace (k, v) fs =
  case find (== k) fs of
    [] -> insert (k, v) fs
    _  -> change (k, v) fs
-- | Apply 'replace' for each pair, folding from the right.
--
-- When the input contains the same name twice, the earlier pair is applied
-- last, so its value wins.
replaces :: [(String,String)] -> FieldList -> FieldList
replaces fs field = go fs
  where
    go []            = field
    go (pair : more) = replace pair (go more)
-- | Values of all pairs whose name satisfies the predicate, in list order.
find :: (String -> Bool) -> FieldList -> [String]
find p (FieldList pairs) = [value | (name, value) <- pairs, p name]
-- | Merge two lists by 'replace'-ing every pair of the first into the second.
-- On name clashes the first list's values win.
union :: FieldList -> FieldList -> FieldList
union as bs = foldr replace bs (unFieldList as)
-- | Concatenate two lists, keeping duplicates; the first list's pairs come
-- first.
unionAll :: FieldList -> FieldList -> FieldList
unionAll (FieldList as) (FieldList bs) = FieldList (as ++ bs)
-- | Value of the first pair whose name equals @key@ exactly, or @def@ if
-- there is none.
findWithDefault :: (String,String) -> FieldList -> String
findWithDefault (key, def) fields =
  case find (== key) fields of
    (value : _) -> value
    []          -> def
-- | Like 'findWithDefault', but names are compared case-insensitively
-- (both sides lower-cased with 'toLower').
ifindWithDefault :: (String,String) -> FieldList -> String
ifindWithDefault (key, def) fields =
  case find sameName fields of
    (value : _) -> value
    []          -> def
  where
    needle = map toLower key
    sameName name = map toLower name == needle
-- | Consume the scheme prefix.
--
-- "https" sets 'ssl' and port 443; "http" clears 'ssl' and sets port 80.
-- Any other prefix (the match is case-sensitive) fails the whole parse by
-- emptying the tape and setting the request to 'Nothing'.
_parseProtocol :: State (String,Maybe Request) ()
_parseProtocol = modify step
  where
    -- "https" must be tested before its prefix "http".
    step (tape, req)
      | "https" `isPrefixOf` tape = (drop 5 tape, fmap secure req)
      | "http" `isPrefixOf` tape  = (drop 4 tape, fmap plain req)
      | otherwise                 = ("", Nothing)
    secure r = r { ssl = True, port = 443 }
    plain r = r { ssl = False, port = 80 }
-- | Consume the host: everything up to (not including) the first ':' or '/'.
-- The text is stored in 'host' as-is, without decoding.
_parseHost :: State (String,Maybe Request) ()
_parseHost = modify step
  where
    step (tape, req) =
      let (name, rest) = span (`notElem` ":/") tape
      in (rest, fmap (\r -> r { host = name }) req)
-- | Consume the port text: everything up to (not including) the next '/'.
--
-- Only a non-empty run of decimal digits (RFC 3986: port = *DIGIT) whose
-- value is in 0-65535 is accepted and stored in 'port'. Anything else is
-- consumed and ignored, so the port keeps its scheme default; an empty port
-- (as in "http://host:/") also leaves the default in place.
--
-- The 'Int' 'reads' previously used here accepted hexadecimal ("0x50"),
-- parenthesised ("(80)"), whitespace-padded and negative values, and it
-- silently wrapped oversized numbers.
_parsePort :: State (String,Maybe Request) ()
_parsePort = do
  (tape, req) <- get
  let (value, tape') = break (== '/') tape
  case validPort value of
    Just p  -> put (tape', fmap (\r -> r { port = p }) req)
    Nothing -> put (tape', req)
  where
    validPort :: String -> Maybe Int
    validPort digits
      | null digits || not (all isDigit digits) = Nothing
      -- Compare as Integer so that long digit strings cannot overflow.
      | n > 65535 = Nothing
      | otherwise = Just (fromInteger n)
      where
        n = read digits :: Integer
-- | Consume the path: everything up to (not including) the first '?'.
--
-- The path is split on '/' with 'splitBy', each piece is decoded with
-- 'decodeWithDefault ""', and an empty segment is put first so that
-- 'showPath' renders a leading '/'.
_parsePath :: State (String,Maybe Request) ()
_parsePath = modify step
  where
    step (tape, req) =
      let (rawPath, rest) = break (== '?') tape
          segments = "" : map (decodeWithDefault "") (splitBy (== '/') rawPath)
      in (rest, fmap (\r -> r { pathComps = segments }) req)
-- | Consume the query string: everything up to (not including) the first '#'.
--
-- Fields are split on '&' with 'splitBy'. Each field is split at its first
-- '=' into a name and value, both decoded with 'decodeWithDefault ""'; a
-- field without '=' gets an empty value. Fields that end up as ("", "") are
-- dropped.
_parseQString :: State (String,Maybe Request) ()
_parseQString = modify step
  where
    step (tape, req) =
      let (query, rest) = break (== '#') tape
          pairs = map parseField (splitBy (== '&') query)
          fields = fromList (filter (/= ("", "")) pairs)
      in (rest, fmap (\r -> r { qString = fields }) req)
    parseField field =
      case break (== '=') field of
        (k, '=' : v) -> (decodeWithDefault "" k, decodeWithDefault "" v)
        (k, _)       -> (decodeWithDefault "" k, "")
-- | Consume the single character @c@ if it is next on the tape.
--
-- If it is not next and @required@ is 'True', the whole parse fails: the tape
-- is emptied and the request becomes 'Nothing'. If it is optional, the state
-- is left untouched.
_parseSymbol :: (Char,Bool) -> State (String,Maybe Request) ()
_parseSymbol (c, required) = modify step
  where
    step (tape, req) =
      case tape of
        (x : rest) | x == c -> (rest, req)
        _ | required        -> ("", Nothing)
          | otherwise       -> (tape, req)
-- | Methods are shown as their upper-case names, e.g. @GET@.
instance Show Method where
  show GET     = "GET"
  show POST    = "POST"
  show DELETE  = "DELETE"
  show CONNECT = "CONNECT"
  show HEAD    = "HEAD"
  show TRACE   = "TRACE"
  show PUT     = "PUT"
-- | Parse an upper-case method name such as @GET@ or @POST@
-- (case-sensitive).
--
-- The input is tokenised with 'lex', which skips leading whitespace, and the
-- unconsumed remainder is returned with the result, as the 'ReadS' contract
-- requires. Matching only an input that is exactly a method name broke
-- 'read' on padded input and on composite values such as @[GET,POST]@,
-- because the list parser hands the element parser the rest of the string.
instance Read Method where
  readsPrec _ input =
    [ (m, rest)
    | (token, rest) <- lex input
    , Just m <- [lookup token methodNames]
    ]
    where
      methodNames =
        [ ("GET", GET)
        , ("POST", POST)
        , ("DELETE", DELETE)
        , ("CONNECT", CONNECT)
        , ("HEAD", HEAD)
        , ("TRACE", TRACE)
        , ("PUT", PUT)
        ]
-- | Parse the version strings HTTP\/1.0 and HTTP\/1.1.
--
-- Leading whitespace is skipped, and whatever follows the version text is
-- returned as the unconsumed remainder, as the 'ReadS' contract requires.
-- This lets 'read' accept surrounding whitespace and lets the instance work
-- inside lists and tuples.
instance Read Version where
  readsPrec _ input =
    [ (v, drop (length tag) trimmed)
    | (tag, v) <- [("HTTP/1.0", Http10), ("HTTP/1.1", Http11)]
    , tag `isPrefixOf` trimmed
    ]
    where
      trimmed = dropWhile isSpace input
-- | Versions are shown in request-line form, e.g. @HTTP\/1.1@.
instance Show Version where
  show Http10 = "HTTP/1.0"
  show Http11 = "HTTP/1.1"
-- | Renders the list as @name=value@ pairs joined by '&', with every name and
-- value passed through 'encode'. 'showQString' relies on this format, so this
-- instance is not a Haskell-syntax 'Show'.
instance Show FieldList where
  showsPrec _ (FieldList pairs) =
    showString (intercalate "&" [encode k ++ "=" ++ encode v | (k, v) <- pairs])
-- | Field lists combine by concatenation: the left operand's pairs come first
-- and duplicates are kept.
--
-- Since GHC 8.4 (base 4.11), 'Semigroup' is a superclass of 'Monoid', so a
-- 'Monoid' instance without a 'Semigroup' instance does not compile. The
-- combining operation therefore lives in '(<>)' (exported by the Prelude
-- from base 4.11), and 'mappend' uses its default, which is '(<>)'.
instance Semigroup FieldList where
  FieldList xs <> FieldList ys = FieldList (xs ++ ys)

instance Monoid FieldList where
  mempty = FieldList []
-- | Serialised exactly like the underlying list of name/value pairs.
instance Bi.Binary FieldList where
  put (FieldList pairs) = Bi.put pairs
  get = do
    pairs <- Bi.get
    return (FieldList pairs)