module Network.HTTP.Enumerator
( Request (..)
, Response (..)
, http
, parseUrl
, httpLbs
, simpleHttp
, withHttpEnumerator
) where
import Control.Arrow (first)
import Control.Exception (Exception, bracket, onException, throwIO)
import Control.Monad (forM_)
import Control.Monad.IO.Class (liftIO)
import Data.Bits
import Data.Char (toLower)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)

import Control.Failure
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import Data.Enumerator hiding (head, map, break)
import qualified Data.Enumerator as E
import Network.Socket
import qualified Network.Socket.ByteString as B
import OpenSSL
import qualified OpenSSL.Session as SSL

import Network.HTTP.Enumerator.HttpParser
-- | Run an action inside 'withOpenSSL'.
--
-- NOTE(review): presumably this must wrap any use of HTTPS requests so that
-- the OpenSSL library is initialised; confirm against the HsOpenSSL docs.
withHttpEnumerator :: IO a -> IO a
withHttpEnumerator action = withOpenSSL action
-- | Resolve @host'@ and open a TCP stream socket connected to @port'@ on the
-- first address returned by 'getAddrInfo'.
--
-- If 'connect' throws, the freshly created socket is closed before the
-- exception propagates, so a failed connection attempt does not leak a
-- socket.
getSocket :: String -> Int -> IO Socket
getSocket host' port' = do
    addrs <- getAddrInfo Nothing (Just host') (Just $ show port')
    addr <- case addrs of
        a : _ -> return a
        -- getAddrInfo is documented to throw rather than return an empty
        -- list; handle it explicitly instead of relying on the partial 'head'.
        [] -> ioError $ userError $ "getSocket: no addresses for " ++ host'
    sock <- socket (addrFamily addr) Stream defaultProtocol
    connect sock (addrAddress addr) `onException` sClose sock
    return sock
-- | Open a plain TCP connection to @host'@:@port'@ and run @f@ with an
-- 'HttpConn' backed by 'B.recv' / 'B.sendAll' on that socket.
--
-- The socket is closed when @f@ finishes, including when @f@ throws. The
-- original closed it only on normal return, so any exception leaked it.
withSocketConn :: String -> Int -> (HttpConn -> IO a) -> IO a
withSocketConn host' port' f =
    bracket (getSocket host' port') sClose $ \sock ->
        f HttpConn
            { hcRead = B.recv sock
            , hcWrite = B.sendAll sock
            }
-- | Open a TLS connection to @host'@:@port'@ using HsOpenSSL and run @f@ with
-- an 'HttpConn' backed by 'SSL.read' / 'SSL.write'.
--
-- After @f@ returns, the TLS session is shut down unidirectionally. The
-- underlying TCP socket is always closed: the original never closed it at
-- all ('SSL.shutdown' only ends the TLS session), and leaked it on exceptions.
withOpenSslConn :: String -> Int -> (HttpConn -> IO a) -> IO a
withOpenSslConn host' port' f = do
    -- NOTE(review): the context is used as returned by 'SSL.context'; this
    -- block configures no certificate or hostname verification. Confirm
    -- whether HsOpenSSL's defaults are acceptable for untrusted networks.
    ctx <- SSL.context
    bracket (getSocket host' port') sClose $ \sock -> do
        ssl <- SSL.connection ctx sock
        SSL.connect ssl
        a <- f HttpConn
            { hcRead = SSL.read ssl
            , hcWrite = SSL.write ssl
            }
        SSL.shutdown ssl SSL.Unidirectional
        return a
-- | A minimal bidirectional byte channel, letting 'http' run unchanged over
-- either a plain socket ('withSocketConn') or an OpenSSL session
-- ('withOpenSslConn').
data HttpConn = HttpConn
    { hcRead :: Int -> IO S.ByteString
      -- ^ Read a chunk; the 'Int' is passed straight through as the requested
      -- size. 'connToEnum' treats an empty result as end of input.
    , hcWrite :: S.ByteString -> IO ()
      -- ^ Send bytes to the peer.
    }
-- | Turn a connection into an 'Enumerator' that repeatedly calls 'hcRead'
-- and feeds each non-empty chunk to the iteratee.
--
-- It stops when 'hcRead' returns an empty 'S.ByteString' (treated as end of
-- input) or as soon as the iteratee leaves the 'Continue' state.
connToEnum :: HttpConn -> Enumerator S.ByteString IO a
connToEnum (HttpConn r _) =
    Iteratee . loop
  where
    -- Bytes requested per read. The original asked for 2 bytes at a time,
    -- which costs one recv/SSL_read call per two bytes of response.
    chunkSize = 4096
    loop (Continue k) = do
        bs <- r chunkSize
        if S.null bs
            then return $ Continue k
            else runIteratee (k $ Chunks [bs]) >>= loop
    loop step = return step
-- | Description of an HTTP request, as consumed by 'http' and produced by
-- 'parseUrl'.
data Request = Request
    { host :: S.ByteString
      -- ^ Host to connect to; also used for the @Host@ header.
    , port :: Int
      -- ^ TCP port to connect to.
    , secure :: Bool
      -- ^ 'True' selects an OpenSSL connection instead of a plain socket.
    , requestHeaders :: [(S.ByteString, S.ByteString)]
      -- ^ Extra headers, sent after @Host@ and @Content-Length@.
    , path :: S.ByteString
      -- ^ Path placed in the request line.
    , queryString :: [(S.ByteString, S.ByteString)]
      -- ^ Parameters, URL-encoded and appended to the path.
    , requestBody :: L.ByteString
      -- ^ Body bytes; their length becomes @Content-Length@.
    , method :: S.ByteString
      -- ^ Request method, the first token of the request line.
    }
    deriving Show
-- | The result of 'http': status line code, headers, and the value the body
-- iteratee produced.
data Response a = Response
    { statusCode :: Int
      -- ^ Numeric status code from the status line.
    , responseHeaders :: [(S.ByteString, S.ByteString)]
      -- ^ Headers exactly as received (names are not case-normalised).
    , responseBody :: a
      -- ^ Result of running the caller's body iteratee.
    }
-- | Perform a single HTTP/1.1 request and run @bodyIter@ over the response
-- body.
--
-- The request line, the @Host@ and @Content-Length@ headers, the caller's
-- headers and the body are written to the connection. The response headers
-- are then parsed with 'iterHeaders'. The body is read as chunked if any
-- @Transfer-Encoding@ header equals @chunked@ (compared case-insensitively),
-- otherwise as exactly @Content-Length@ bytes, otherwise as empty.
--
-- Any error reported by the enumerator pipeline, or by @bodyIter@, is
-- re-thrown with 'throwIO'.
http :: Request -> Iteratee S.ByteString IO a -> IO (Response a)
http (Request {..}) bodyIter = do
    let h' = S8.unpack host
    res <- (if secure then withOpenSslConn else withSocketConn) h' port go
    case res of
        Left e -> throwIO e
        Right x -> return x
  where
    -- Host header value: the port is omitted only when it is the scheme
    -- default.
    hh
        | port == 80 && not secure = host
        | port == 443 && secure = host
        | otherwise = host `S.append` S8.pack (':' : show port)
    go hc = do
        -- Request line: METHOD SP path[?query] SP HTTP/1.1 CRLF
        hcWrite hc $ S.concat
            $ method
            : " "
            : path
            : renderQS queryString [" HTTP/1.1\r\n"]
        let headers' = ("Host", hh)
                     : ("Content-Length", S8.pack $ show $ L.length requestBody)
                     : requestHeaders
        forM_ headers' $ \(k, v) -> hcWrite hc $ S.concat [k, ": ", v, "\r\n"]
        hcWrite hc "\r\n"
        mapM_ (hcWrite hc) $ L.toChunks requestBody
        run $ connToEnum hc $$ do
            ((_, sc, _), hs) <- iterHeaders
            -- Header names are case-insensitive; lowercase them for lookups.
            let hs' = map (first $ S8.map toLower) hs
                mcl = lookup "content-length" hs'
                -- Transfer-coding names are case-insensitive too, so
                -- "Chunked" must also select the chunked decoder.
                isChunked = "chunked" `elem`
                    [S8.map toLower v | ("transfer-encoding", v) <- hs']
            body' <-
                if isChunked
                    then iterChunks
                    else case mcl >>= readMay . S8.unpack of
                        Just len -> takeLBS len
                        Nothing -> return []
            ebody'' <- liftIO $ run $ enumList 1 body' $$ bodyIter
            case ebody'' of
                Left err -> liftIO $ throwIO err
                Right body'' ->
                    return $ Response
                        { statusCode = sc
                        , responseHeaders = hs
                        , responseBody = body''
                        }
-- | Consume exactly @len@ bytes from the stream (or fewer if the stream
-- ends first) and return them as a list of chunks.
--
-- If a chunk holds more bytes than are still needed, it is split. Only the
-- needed prefix is returned, and the remainder is handed back to the stream
-- as leftover input. In the original, the counter went negative in that case
-- and the @takeLBS 0@ base case was never reached, so it kept consuming until
-- end of input.
takeLBS :: Monad m => Int -> Iteratee S.ByteString m [S.ByteString]
takeLBS len
    | len <= 0 = return []
    | otherwise = do
        mbs <- E.head
        case mbs of
            Nothing -> return []
            Just bs
                | S.length bs <= len -> do
                    rest <- takeLBS (len - S.length bs)
                    return $ bs : rest
                | otherwise -> do
                    let (wanted, extra) = S.splitAt len bs
                    E.yield [wanted] (Chunks [extra])
-- | Render query parameters as @?k1=v1&k2=v2...@ and prepend the pieces to
-- @suffix@.
--
-- Every key and value is escaped byte-by-byte with 'encodeUrlChar'. An empty
-- parameter list leaves @suffix@ unchanged (no @?@ is emitted).
renderQS :: [(S.ByteString, S.ByteString)]
         -> [S.ByteString]
         -> [S.ByteString]
renderQS [] suffix = suffix
renderQS (firstParam : otherParams) suffix =
    concat (render "?" firstParam : map (render "&") otherParams) ++ suffix
  where
    render separator (key, value) = [separator, escape key, "=", escape value]
    escape = S8.concatMap (S8.pack . encodeUrlChar)
-- | Percent-encode one character for use in a query string.
--
-- The unreserved characters @A-Z a-z 0-9 - _ . ~@ pass through unchanged and
-- a space becomes @+@. Any other character becomes @%XY@, where @XY@ is the
-- upper-case hex of the low byte of its code point. Callers pass 'S8'
-- characters, so that is the whole code point.
encodeUrlChar :: Char -> String
encodeUrlChar c
    | isUnreserved = [c]
    | c == ' ' = "+"
    | otherwise = ['%', hexDigit hi, hexDigit lo]
  where
    isUnreserved =
        ('A' <= c && c <= 'Z')
            || ('a' <= c && c <= 'z')
            || ('0' <= c && c <= '9')
            || c == '-' || c == '_' || c == '.' || c == '~'
    code = fromEnum c
    -- Both nibbles are reduced mod 16, so hexDigit only ever sees 0..15.
    hi = (code `div` 16) `mod` 16
    lo = code `mod` 16
    hexDigit :: Int -> Char
    hexDigit x
        | x < 10 = toEnum (x + fromEnum '0')
        | otherwise = toEnum (x - 10 + fromEnum 'A')
-- | Reported (via 'failure') by 'parseUrl' when a URL cannot be used.
-- The first field is the offending URL, the second a short reason such as
-- @"Invalid scheme"@ or @"Invalid port"@.
data InvalidUrlException =
    InvalidUrlException String String
    deriving (Show, Typeable)

instance Exception InvalidUrlException
-- | Parse an absolute @http://@ or @https://@ URL into a GET 'Request'.
--
-- The text after the scheme is handed to 'parseUrl1'; any other scheme is
-- reported through 'failure' as an 'InvalidUrlException'.
parseUrl :: Failure InvalidUrlException m => String -> m Request
parseUrl s =
    case splitAt 7 s of
        ("http://", rest) -> parseUrl1 s False rest
        _ -> case splitAt 8 s of
            ("https://", rest) -> parseUrl1 s True rest
            _ -> failure $ InvalidUrlException s "Invalid scheme"
-- | Build a GET 'Request' from the part of a URL after @scheme://@.
--
-- @full@ is the complete URL (used only in error messages), @sec@ selects
-- HTTPS, and @s@ is the remainder. The fragment (@#...@) is removed first, so
-- it never reaches the path or query. The authority ends at the first @/@ or
-- @?@. A missing port defaults to 80 or 443, and an unparsable port is
-- reported via 'failure'. An empty path becomes @/@.
parseUrl1 :: Failure InvalidUrlException m
          => String -> Bool -> String -> m Request
parseUrl1 full sec s = do
    port' <- mport
    return Request
        { host = S8.pack hostname
        , port = port'
        , secure = sec
        , requestHeaders = []
        , path = S8.pack $ if null path' then "/" else path'
        , queryString = parseQueryString $ S8.pack qstring
        , requestBody = L.empty
        , method = "GET"
        }
  where
    -- Drop the fragment before splitting. It is never sent to the server,
    -- and a '/' or '?' inside it must not be treated as URL structure.
    withoutFragment = takeWhile (/= '#') s
    -- The authority ends at the first '/' or '?', so "host?x=1" keeps the
    -- query out of the hostname.
    (beforeSlash, afterSlash) = break (\c -> c == '/' || c == '?') withoutFragment
    (hostname, portStr) = break (== ':') beforeSlash
    (path', qstring') = break (== '?') afterSlash
    qstring = case qstring' of
        '?':x -> x
        _ -> qstring'
    mport =
        case (portStr, sec) of
            ("", False) -> return 80
            ("", True) -> return 443
            (':':rest, _) ->
                case readMay rest of
                    Just i -> return i
                    Nothing -> failure $ InvalidUrlException full "Invalid port"
            x -> error $ "parseUrl1: this should never happen: " ++ show x
-- | Parse a URL-encoded query string into decoded key/value pairs.
--
-- A single leading @?@ (byte 63) is skipped. Pairs are separated by @&@
-- (byte 38), and each pair is split at its first @=@ (byte 61); both halves
-- are decoded with 'qsDecode'. A pair without @=@ gets an empty value.
parseQueryString :: S.ByteString -> [(S.ByteString, S.ByteString)]
parseQueryString input = pairs (stripQuestion input)
  where
    stripQuestion q
        | not (S.null q) && S.head q == 63 = S.tail q
        | otherwise = q
    pairs q
        | S.null q = []
        | otherwise =
            let (segment, rest) = breakDiscard 38 q
                (k, v) = breakDiscard 61 segment
            in (qsDecode k, qsDecode v) : pairs rest
-- | Decode one URL-encoded query component.
--
-- @+@ (43) becomes a space (32), and @%XY@ with two hex digits becomes the
-- byte @0xXY@. A @%@ not followed by two hex digits is kept as a literal
-- @%@. The output is never longer than the input, which is why
-- 'S.unfoldrN' is bounded by the input length.
qsDecode :: S.ByteString -> S.ByteString
qsDecode z = fst $ S.unfoldrN (S.length z) go z
  where
    go bs =
        case uncons bs of
            Nothing -> Nothing
            Just (43, ws) -> Just (32, ws)
            Just (37, ws) -> Just $ fromMaybe (37, ws) $ do
                (x, xs) <- uncons ws
                x' <- hexVal x
                (y, ys) <- uncons xs
                y' <- hexVal y
                Just (combine x' y', ys)
            Just (w, ws) -> Just (w, ws)
    -- Value of an ASCII hex digit: '0'-'9' (48-57), 'A'-'F' (65-70),
    -- 'a'-'f' (97-102).
    hexVal w
        | 48 <= w && w <= 57 = Just $ w - 48
        | 65 <= w && w <= 70 = Just $ w - 55
        | 97 <= w && w <= 102 = Just $ w - 87
        | otherwise = Nothing
    combine :: Word8 -> Word8 -> Word8
    combine a b = shiftL a 4 .|. b
-- | Split a 'S.ByteString' into its first byte and the remainder, or return
-- 'Nothing' when it is empty. This is the same as 'S.uncons'.
uncons :: S.ByteString -> Maybe (Word8, S.ByteString)
uncons = S.uncons
-- | Split at the first occurrence of byte @w@ and drop that byte.
-- If @w@ does not occur, the whole input is returned with an empty rest.
breakDiscard :: Word8 -> S.ByteString -> (S.ByteString, S.ByteString)
breakDiscard w s =
    case S.elemIndex w s of
        Nothing -> (s, S.empty)
        Just i -> (S.take i s, S.drop (i + 1) s)
-- | Like 'http', but collects every body chunk with 'consume' and joins the
-- chunks into a lazy 'L.ByteString'.
httpLbs :: Request -> IO (Response L.ByteString)
httpLbs req = http req (fmap L.fromChunks consume)
-- | Fetch a URL with a GET request and return the whole body.
--
-- NOTE(review): URL errors from 'parseUrl' surface through the 'Failure'
-- instance for 'IO', presumably as a thrown 'InvalidUrlException'; confirm
-- against the failure package.
simpleHttp :: String -> IO (Response L.ByteString)
simpleHttp url = httpLbs =<< parseUrl url
-- | Total version of 'read': returns 'Nothing' unless the whole string is a
-- single valid parse.
--
-- Surrounding whitespace is allowed, but trailing garbage is not. The
-- original accepted any prefix parse, so a port of @"80abc"@ or a
-- Content-Length of @"12xyz"@ was silently read as 80 or 12.
readMay :: Read a => String -> Maybe a
readMay s =
    case [x | (x, rest) <- reads s, ("", "") <- lex rest] of
        [x] -> Just x
        _ -> Nothing