module Network.OAuth.Http.Request (
Request(..)
,FieldList()
,Version(..)
,Method(..)
,fromList
,singleton
,empty
,toList
,parseQString
,find
,findWithDefault
,ifindWithDefault
,change
,insert
,replace
,replaces
,union
,unionAll
,showURL
,showQString
,showProtocol
,showAuthority
,showPath
,parseURL
) where
import Data.Char (isSpace, toLower)
import Data.List (intercalate, isPrefixOf)
import Data.Monoid

import Control.Monad.State
import qualified Data.Binary as Bi
import qualified Data.ByteString.Lazy as B

import Network.OAuth.Http.PercentEncoding
import Network.OAuth.Http.Util
-- | HTTP request methods. Their textual form is provided by hand-written
-- 'Show' / 'Read' instances rather than derived ones.
data Method
  = GET
  | POST
  | PUT
  | DELETE
  | TRACE
  | CONNECT
  | HEAD
  deriving (Eq)
-- | HTTP protocol versions (rendered as @HTTP/1.0@ and @HTTP/1.1@ by the
-- hand-written 'Show' instance).
data Version
  = Http10
  | Http11
  deriving (Eq)
-- | An ordered list of key/value pairs, used for request headers and for
-- query-string parameters. Order is preserved and a key may occur more than
-- once.
newtype FieldList = FieldList
  { unFieldList :: [(String, String)] -- ^ the underlying pairs, in order
  }
  deriving (Eq, Ord)
-- | An HTTP request: where it is sent ('ssl', 'host', 'port', 'pathComps',
-- 'qString') and what it carries ('version', 'method', 'reqHeaders',
-- 'reqPayload').
data Request = ReqHttp
  { version    :: Version       -- ^ protocol version
  , ssl        :: Bool          -- ^ 'True' for the @https@ scheme
  , host       :: String        -- ^ host part of the authority
  , port       :: Int           -- ^ TCP port
  , method     :: Method        -- ^ request method
  , reqHeaders :: FieldList     -- ^ request headers
  , pathComps  :: [String]
    -- ^ path segments, passed through 'encode' by 'showPath'; 'parseURL'
    -- puts an empty segment first so the rendered path starts with @/@
  , qString    :: FieldList     -- ^ query-string parameters
  , reqPayload :: B.ByteString  -- ^ request body
  }
  deriving (Eq, Show)
-- | URL scheme of the request: @"https"@ when 'ssl' is set, @"http"@ otherwise.
showProtocol :: Request -> String
showProtocol req = if ssl req then "https" else "http"
-- | Authority part of the URL. The port is left out when it is the default
-- for the scheme (443 with 'ssl', 80 without) and appended as @:port@
-- otherwise.
showAuthority :: Request -> String
showAuthority req
  | port req == defaultPort = host req
  | otherwise               = host req ++ ":" ++ show (port req)
  where
    defaultPort = if ssl req then 443 else 80
-- | Pass every path segment through 'encode' and join them with @/@.
--
-- NOTE(review): no slash is prepended here; the path only starts with @/@
-- when the first segment is empty, which is what 'parseURL' produces.
showPath :: Request -> String
showPath req = intercalate "/" [encode segment | segment <- pathComps req]
-- | The query string of the request, without a leading @?@, as rendered by
-- the 'Show' instance of 'FieldList'.
showQString :: Request -> String
showQString req = show (qString req)
-- | Render the full URL: scheme, @://@, authority and path, followed by @?@
-- and the query string only when there is at least one parameter.
showURL :: Request -> String
showURL req =
  concat [showProtocol req, "://", showAuthority req, showPath req, query]
  where
    query
      | null (unFieldList (qString req)) = ""
      | otherwise                        = '?' : showQString req
-- | Parse an absolute @http@ / @https@ URL into a 'GET' request.
--
-- The scheme, the @://@ separator and a @/@ after the authority are
-- mandatory; if any of them is missing the result is 'Nothing'. The port
-- starts from the scheme's default (set while reading the scheme), and the
-- remaining fields start from the defaults below (HTTP/1.1, no headers,
-- empty payload).
parseURL :: String -> Maybe Request
parseURL tape = snd (execState steps (tape, Just defaults))
  where
    steps = do
      _parseProtocol
      mapM_ _parseSymbol [(':', True), ('/', True), ('/', True)]
      _parseHost
      _parseSymbol (':', False)
      _parsePort
      _parseSymbol ('/', True)
      _parsePath
      _parseSymbol ('?', False)
      _parseQString
    defaults = ReqHttp
      { version    = Http11
      , ssl        = False
      , method     = GET
      , host       = "127.0.0.1"
      , port       = 80
      , reqHeaders = fromList []
      , pathComps  = []
      , qString    = fromList []
      , reqPayload = B.empty
      }
-- | Parse a bare query string such as @"a=1&b=2"@ into a 'FieldList'.
-- Input after the first @#@ is ignored. Pairs are split on @&@ and @=@,
-- keys and values go through 'decodeWithDefault', and pairs that end up as
-- @("","")@ are dropped.
parseQString :: String -> FieldList
parseQString tape =
  case snd (execState _parseQString (tape, Just blank)) of
    Just req -> qString req
    Nothing  -> fromList []
  where
    -- Placeholder request; only its 'qString' field is read back.
    blank = ReqHttp
      { version    = Http11
      , ssl        = False
      , method     = GET
      , host       = "127.0.0.1"
      , port       = 80
      , reqHeaders = fromList []
      , pathComps  = []
      , qString    = fromList []
      , reqPayload = B.empty
      }
-- | Build a 'FieldList' from key/value pairs, keeping their order.
fromList :: [(String,String)] -> FieldList
fromList pairs = FieldList pairs
-- | The key/value pairs of a 'FieldList', in order.
toList :: FieldList -> [(String,String)]
toList (FieldList pairs) = pairs
-- | A 'FieldList' holding exactly one pair.
singleton :: (String,String) -> FieldList
singleton kv = fromList [kv]
-- | The 'FieldList' with no pairs.
empty :: FieldList
empty = FieldList []
-- | Overwrite the value of every pair whose key equals the given key.
-- Other pairs and the order of the list are untouched; when the key is
-- absent the list comes back unchanged (nothing is added).
change :: (String,String) -> FieldList -> FieldList
change (key,val) (FieldList pairs) = FieldList (map update pairs)
  where
    update (k,v)
      | k == key  = (k,val)
      | otherwise = (k,v)
-- | Put a pair at the front of the list, even if its key is already present.
insert :: (String,String) -> FieldList -> FieldList
insert kv (FieldList pairs) = FieldList (kv : pairs)
-- | Set the value for a key: every existing pair with that key is
-- overwritten (via 'change'); if there is none, a new pair is put in front
-- (via 'insert').
replace :: (String,String) -> FieldList -> FieldList
replace (k,v) fs =
  case find (== k) fs of
    [] -> insert (k,v) fs
    _  -> change (k,v) fs
-- | 'replace' every pair of the first argument into the field list. The pairs
-- are folded in from the right, so for a repeated key the leftmost value wins.
replaces :: [(String,String)] -> FieldList -> FieldList
replaces updates fields = foldr replace fields updates
-- | Values of all pairs whose key satisfies the predicate, in list order.
find :: (String -> Bool) -> FieldList -> [String]
find p (FieldList pairs) = [v | (k,v) <- pairs, p k]
-- | Merge the first list into the second with 'replace': keys of the first
-- list override (or are added to) the second. If the first list repeats a
-- key, its leftmost value wins.
union :: FieldList -> FieldList -> FieldList
union overrides base = replaces (unFieldList overrides) base
-- | Concatenate two field lists, keeping duplicate keys; the pairs of the
-- first list come first.
unionAll :: FieldList -> FieldList -> FieldList
unionAll (FieldList xs) (FieldList ys) = FieldList (xs ++ ys)
-- | Value of the first pair whose key equals the given key exactly, or the
-- supplied default when there is no such pair.
findWithDefault :: (String,String) -> FieldList -> String
findWithDefault (key,def) fields =
  case find (== key) fields of
    (v:_) -> v
    []    -> def
-- | Like 'findWithDefault', but keys are compared after mapping 'toLower'
-- over both, i.e. case-insensitively.
ifindWithDefault :: (String,String) -> FieldList -> String
ifindWithDefault (key,def) fields =
  case find sameKey fields of
    (v:_) -> v
    []    -> def
  where
    foldedKey = map toLower key
    sameKey k = map toLower k == foldedKey
-- | Consume the URL scheme. @https@ sets 'ssl' and port 443; @http@ clears
-- 'ssl' and sets port 80; anything else empties the tape and marks the parse
-- as failed. @https@ must be tried first since @http@ is a prefix of it.
_parseProtocol :: State (String,Maybe Request) ()
_parseProtocol = modify step
  where
    step (tape, req)
      | "https" `isPrefixOf` tape = (drop 5 tape, fmap (\r -> r {ssl = True, port = 443}) req)
      | "http" `isPrefixOf` tape  = (drop 4 tape, fmap (\r -> r {ssl = False, port = 80}) req)
      | otherwise                 = ("", Nothing)
-- | Consume the host: everything up to the next @:@ or @/@ (or the end of
-- the tape) becomes the request's 'host'.
_parseHost :: State (String,Maybe Request) ()
_parseHost = modify step
  where
    step (tape, req) =
      let (hostName, rest) = break (\c -> c == ':' || c == '/') tape
      in (rest, fmap (\r -> r {host = hostName}) req)
-- | Consume the port: everything up to the next @/@.
--
-- * Empty (no port given, or a bare @:@): the port already set on the
--   request (the scheme default) is kept.
-- * One to five ASCII digits with a value of at most 65535: used as the port.
-- * Anything else: the parse is marked as failed ('Nothing').
--
-- The digit check is needed because 'reads' alone would skip whitespace and
-- accept a minus sign or @0x@/@0o@ literals, and reading straight into 'Int'
-- wraps on overflow, so malformed ports were either accepted or silently
-- replaced by the default.
_parsePort :: State (String,Maybe Request) ()
_parsePort = get >>= \(tape, req) ->
  let (value, tape') = break (== '/') tape
  in if null value
       then put (tape', req)
       else case portNumber value of
              Just p  -> put (tape', fmap (\r -> r {port = p}) req)
              Nothing -> put (tape', Nothing)
  where
    portNumber :: String -> Maybe Int
    portNumber ds
      | length ds <= 5
      , all (`elem` ['0' .. '9']) ds
      , [(n, "")] <- (reads ds :: [(Integer, String)])
      , n <= 65535
      = Just (fromInteger n)
      | otherwise = Nothing
-- | Consume the path: everything up to the first @?@ or @#@. It is split on
-- @/@, each segment goes through 'decodeWithDefault', and an empty segment
-- is put in front to stand for the slash consumed just before, so that
-- 'showPath' renders a leading @/@ again.
--
-- Stopping at @#@ as well as @?@ keeps a fragment (RFC 3986, section 3.5)
-- out of the last path segment; the fragment stays on the tape for the
-- following steps.
_parsePath :: State (String,Maybe Request) ()
_parsePath = get >>= \(tape, req) ->
  let (value, tape') = break (`elem` "?#") tape
      segments = "" : map (decodeWithDefault "") (splitBy (== '/') value)
  in put (tape', fmap (\r -> r {pathComps = segments}) req)
-- | Consume the query string: everything up to the first @#@. It is split on
-- @&@; each piece is split at its first @=@ into key and value (a piece
-- without @=@ gets an empty value), both halves go through
-- 'decodeWithDefault', and pairs that end up as @("","")@ are dropped.
_parseQString :: State (String,Maybe Request) ()
_parseQString = modify step
  where
    step (tape, req) =
      let (query, rest) = break (== '#') tape
          pairs = filter (/= ("", "")) [field piece | piece <- splitBy (== '&') query]
      in (rest, fmap (\r -> r {qString = fromList pairs}) req)
    field piece =
      case break (== '=') piece of
        (k, '=' : v) -> (unescape k, unescape v)
        (k, _)       -> (unescape k, "")
    unescape = decodeWithDefault ""
-- | Consume the character @c@ if it is next on the tape. If it is not there
-- and @required@ is set, the tape is emptied and the parse marked as failed;
-- if it is optional, the state is left as it was.
_parseSymbol :: (Char,Bool) -> State (String,Maybe Request) ()
_parseSymbol (c,required) = modify step
  where
    step (x:rest, req) | x == c = (rest, req)
    step (tape, req)
      | required  = ("", Nothing)
      | otherwise = (tape, req)
-- | Renders a method as its upper-case name, e.g. @GET@.
instance Show Method where
  showsPrec _ m = showString name
    where
      name = case m of
        GET     -> "GET"
        POST    -> "POST"
        PUT     -> "PUT"
        DELETE  -> "DELETE"
        TRACE   -> "TRACE"
        CONNECT -> "CONNECT"
        HEAD    -> "HEAD"
-- | Parses the upper-case method names produced by the 'Show' instance.
--
-- Uses 'lex', so leading whitespace is skipped and the unconsumed rest of
-- the input is returned, as the 'ReadS' contract requires. Previously only
-- an exact whole-string match worked, so 'Method' could not be read inside
-- larger values such as @[GET,POST]@ or @Just GET@.
instance Read Method where
  readsPrec _ input =
    [ (m, rest)
    | (token, rest) <- lex input
    , Just m <- [lookup token methods]
    ]
    where
      methods =
        [ ("GET", GET), ("POST", POST), ("PUT", PUT), ("DELETE", DELETE)
        , ("TRACE", TRACE), ("CONNECT", CONNECT), ("HEAD", HEAD)
        ]
-- | Parses @HTTP/1.0@ and @HTTP/1.1@, the forms produced by the 'Show'
-- instance.
--
-- Leading whitespace is skipped and the unconsumed rest of the input is
-- returned, as the 'ReadS' contract requires; previously only an exact
-- whole-string match worked. 'lex' is not used here because it would split
-- @HTTP/1.1@ into several tokens.
instance Read Version where
  readsPrec _ input =
    [ (v, drop (length tag) trimmed)
    | (tag, v) <- [("HTTP/1.0", Http10), ("HTTP/1.1", Http11)]
    , tag `isPrefixOf` trimmed
    ]
    where
      trimmed = dropWhile isSpace input
-- | Renders a version in request-line form, e.g. @HTTP/1.1@.
instance Show Version where
  showsPrec _ Http10 = showString "HTTP/1.0"
  showsPrec _ Http11 = showString "HTTP/1.1"
-- | Renders as a query string: @key=value@ pairs joined with @&@, with keys
-- and values passed through 'encode'. No leading @?@ is emitted.
instance Show FieldList where
  showsPrec _ (FieldList pairs) =
    showString (intercalate "&" [encode k ++ "=" ++ encode v | (k, v) <- pairs])
-- | Concatenation: the pairs of the left list come before those of the right
-- one, and duplicate keys are kept.
--
-- Since base-4.11 (GHC 8.4) 'Semigroup' is a superclass of 'Monoid', so a
-- 'Monoid' instance without a 'Semigroup' instance no longer compiles.
-- 'mappend' keeps its default definition, which is '(<>)'.
instance Semigroup FieldList where
  FieldList xs <> FieldList ys = FieldList (xs ++ ys)

instance Monoid FieldList where
  mempty = FieldList []
-- | Serialised exactly like the underlying list of key/value pairs.
instance Bi.Binary FieldList where
  put (FieldList pairs) = Bi.put pairs
  get = do
    pairs <- Bi.get
    return (FieldList pairs)