module Network.HTTP.Enumerator
(
simpleHttp
, httpLbs
, httpLbsRedirect
, http
, httpRedirect
, redirectIter
, Request (..)
, RequestBody (..)
, Response (..)
, Manager
, newManager
, closeManager
, withManager
, parseUrl
, semiParseUrl
, lbsIter
, urlEncodedBody
, HttpException (..)
) where
import qualified Network.TLS.Client.Enumerator as TLS
import Network (connectTo, PortID (PortNumber))
import qualified Network.Socket as NS
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Char8 as S8
import Data.Enumerator
    ( Iteratee (..), Stream (..), catchError, throwError
    , yield, Step (..), Enumeratee, ($$), joinI, Enumerator, run_
    , returnI, (>==>)
    )
import qualified Data.Enumerator.List as EL
import Network.HTTP.Enumerator.HttpParser
import Control.Exception (Exception, bracket, bracketOnError)
import Control.Arrow (first)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Trans.Class (lift)
import Control.Failure
import Data.Typeable (Typeable)
import Codec.Binary.UTF8.String (encodeString)
import qualified Blaze.ByteString.Builder as Blaze
import Blaze.ByteString.Builder.Enumerator (builderToByteString)
import Data.Monoid (Monoid (..))
import qualified Network.HTTP.Types as W
import qualified Data.CaseInsensitive as CI
import Data.Int (Int64)
import qualified Codec.Zlib.Enum as Z
import Control.Monad.IO.Control (MonadControlIO, liftIOOp)
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.IORef as I
import Control.Applicative ((<$>))
import Data.Certificate.X509 (X509)
import Network.TLS.Extra (certificateVerifyChain, certificateVerifyDomain)
-- | Resolve @host'@ and open a TCP stream connection to @port'@ on the
-- first address returned.
--
-- If 'NS.connect' throws, the freshly created socket is closed before the
-- exception propagates, so a failed connection attempt does not leak a
-- socket. Errors from 'NS.getAddrInfo' propagate unchanged.
getSocket :: String -> Int -> IO NS.Socket
getSocket host' port' = do
    let hints = NS.defaultHints
            { NS.addrFlags = [NS.AI_ADDRCONFIG]
            , NS.addrSocketType = NS.Stream
            }
    -- Per the network package docs, getAddrInfo throws rather than
    -- returning [], so this pattern is not expected to fail.
    (addr:_) <- NS.getAddrInfo (Just hints) (Just host') (Just $ show port')
    bracketOnError
        (NS.socket (NS.addrFamily addr) (NS.addrSocketType addr)
                   (NS.addrProtocol addr))
        NS.sClose
        (\sock -> do
            NS.connect sock (NS.addrAddress addr)
            return sock)
-- | Run a request/response exchange over a plain-TCP connection taken from
-- the manager's pool under the key @(host', port', False)@. A new socket
-- is opened with 'getSocket' when none is pooled.
withSocketConn :: MonadIO m
               => Manager
               -> String
               -> Int
               -> Enumerator Blaze.Builder m ()
               -> Enumerator S.ByteString m a
withSocketConn man host' port' =
    withManagedConn man plainKey openPlain
  where
    plainKey = (host', port', False)
    openPlain = TLS.socketConn <$> getSocket host' port'
-- | Run a request/response exchange over a pooled connection.
--
-- The connection stored under @key@ is taken out of the pool. If there is
-- none, @open@ is run to create one. The request enumerator is sent and
-- the response is streamed into @step@ (see 'withCI'). The connection is
-- then put back under @key@.
--
-- NOTE(review): there is no exception handling here, so if 'withCI' fails
-- the connection is neither closed nor returned to the pool.
withManagedConn
    :: MonadIO m
    => Manager
    -> ConnKey
    -> IO TLS.ConnInfo
    -> Enumerator Blaze.Builder m ()
    -> Enumerator S.ByteString m a
withManagedConn man key open req step = do
    pooled <- liftIO $ takeInsecureSocket man key
    conn <- maybe (liftIO open) return pooled
    result <- withCI conn req step
    liftIO $ putInsecureSocket man key conn
    return result
-- | Run a request/response exchange over a TLS connection taken from the
-- manager's pool under the key @(host', port', True)@.
--
-- When nothing is pooled, a new connection is made with 'connectTo' and
-- wrapped with 'TLS.sslClientConn', which is given @checkCert@ to vet the
-- server's certificate chain.
withSslConn :: MonadIO m
            => ([X509] -> IO Bool)
            -> Manager
            -> String
            -> Int
            -> Enumerator Blaze.Builder m ()
            -> Enumerator S.ByteString m a
withSslConn checkCert man host' port' =
    withManagedConn man (host', port', True) openTls
  where
    openTls = do
        h <- connectTo host' (PortNumber $ fromIntegral port')
        TLS.sslClientConn checkCert h
-- | Perform one exchange on an open connection.
--
-- First the whole request is sent: the 'Blaze.Builder' enumerator is run
-- to completion, rendered to ByteStrings, and fed to the connection's
-- output iteratee. Then the connection's input is enumerated into
-- @step0@.
withCI :: MonadIO m => TLS.ConnInfo -> Enumerator Blaze.Builder m () -> Enumerator S.ByteString m a
withCI ci req step0 = do
    let sendRequest = req $$ joinI (builderToByteString $$ TLS.connIter ci)
    lift (run_ sendRequest)
    TLS.connEnum ci step0
-- | An HTTP request, as consumed by 'http'.
data Request m = Request
    { method :: W.Method
      -- ^ request method, written verbatim at the start of the request line
    , secure :: Bool
      -- ^ connect over TLS instead of plain TCP
    , checkCerts :: [X509] -> IO Bool
      -- ^ server certificate check; only consulted when 'secure' is set
    , host :: W.Ascii
      -- ^ server to connect to; also sent in the @Host@ header
    , port :: Int
      -- ^ TCP port; added to the @Host@ header unless it is the scheme's
      -- default (80 plain, 443 secure)
    , path :: W.Ascii
      -- ^ request path; a leading @/@ is added if missing
    , queryString :: W.Query
      -- ^ rendered after the path when non-empty
    , requestHeaders :: W.RequestHeaders
      -- ^ extra headers, sent after the Host, Content-Length and
      -- Accept-Encoding headers that 'http' adds itself
    , requestBody :: RequestBody m
      -- ^ body sent after the headers
    }
-- | The body of a 'Request'.
--
-- For the builder and enumerator forms the 'Int64' is sent as-is as the
-- @Content-Length@. It is not checked against the bytes actually produced.
data RequestBody m
    = RequestBodyLBS L.ByteString
      -- ^ lazy ByteString; its length is computed
    | RequestBodyBS S.ByteString
      -- ^ strict ByteString; its length is computed
    | RequestBodyBuilder Int64 Blaze.Builder
      -- ^ declared length and a single Builder
    | RequestBodyEnum Int64 (Enumerator Blaze.Builder m ())
      -- ^ declared length and a streaming enumerator of Builders
-- | A response held entirely in memory, as built by 'lbsIter'.
data Response = Response
    { statusCode :: Int
      -- ^ numeric status code
    , responseHeaders :: W.ResponseHeaders
      -- ^ response headers
    , responseBody :: L.ByteString
      -- ^ full body, as delivered to the body iteratee
    }
    deriving (Show, Read, Eq, Typeable)
-- | Enumerator that supplies exactly one element.
-- If the step is still accepting input it gets @[x]@; any other step is
-- returned untouched.
enumSingle :: Monad m => a -> Enumerator a m b
enumSingle x step =
    case step of
        Continue k -> k (Chunks [x])
        _ -> returnI step
-- | Send a request and feed the response to an iteratee.
--
-- The connection comes from the 'Manager' (TLS when 'secure' is set).
-- The request line and headers are written first, with Host,
-- Content-Length and @Accept-Encoding: gzip@ placed in front of
-- 'requestHeaders'; the body follows.
--
-- Once the response headers are parsed, @bodyStep@ receives the status and
-- headers and then sees only this response's body:
--
-- * The body is bounded by chunked decoding or by Content-Length, and is
--   gzip-decoded when @Content-Encoding: gzip@ is present.
-- * With neither framing, the body runs until the connection's input ends.
-- * Responses to HEAD, and 204/304 responses, are given an immediately
--   empty body.
http
    :: MonadIO m
    => Request m
    -> (W.Status -> W.ResponseHeaders -> Iteratee S.ByteString m a)
    -> Manager
    -> Iteratee S.ByteString m a
http Request {..} bodyStep m = do
    let h' = S8.unpack host
    let withConn = if secure then withSslConn checkCerts else withSocketConn
    withConn m h' port requestEnum $$ go
  where
    (contentLength, bodyEnum) =
        case requestBody of
            RequestBodyLBS lbs -> (L.length lbs, enumSingle $ Blaze.fromLazyByteString lbs)
            RequestBodyBS bs -> (fromIntegral $ S.length bs, enumSingle $ Blaze.fromByteString bs)
            RequestBodyBuilder i b -> (i, enumSingle b)
            RequestBodyEnum i enum -> (i, enum)
    -- Leave the port out of the Host header when it is the scheme default.
    hh
        | port == 80 && not secure = host
        | port == 443 && secure = host
        | otherwise = host `mappend` S8.pack (':' : show port)
    headers' = ("Host", hh)
             : ("Content-Length", S8.pack $ show contentLength)
             : ("Accept-Encoding", "gzip")
             : requestHeaders
    requestHeaders' =
        Blaze.fromByteString method
        `mappend` Blaze.fromByteString " "
        `mappend`
            (case S8.uncons path of
                Just ('/', _) -> Blaze.fromByteString path
                _ -> Blaze.fromByteString "/"
                        `mappend` Blaze.fromByteString path)
        `mappend` (if null queryString
                    then mempty
                    else W.renderQueryBuilder True queryString)
        `mappend` Blaze.fromByteString " HTTP/1.1\r\n"
        `mappend` mconcat (flip map headers' $ \(k, v) ->
            Blaze.fromByteString (CI.original k)
            `mappend` Blaze.fromByteString ": "
            `mappend` Blaze.fromByteString v
            `mappend` Blaze.fromByteString "\r\n")
        `mappend` Blaze.fromByteString "\r\n"
    requestEnum = enumSingle requestHeaders' >==> bodyEnum
    go = do
        ((_, sc, sm), hs) <- iterHeaders
        let s = W.Status sc sm
        let hs' = map (first CI.mk) hs
        let mcl = lookup "content-length" hs'
        -- Frame the body so the iteratee sees EOF at the end of this
        -- response rather than at the end of the connection.
        let body' x =
                if ("transfer-encoding", "chunked") `elem` hs'
                    then joinI $ chunkedEnumeratee $$ x
                    else case mcl >>= readMay . S8.unpack of
                        Just len -> joinI $ takeLBS len $$ x
                        Nothing -> x
        let decompress x =
                if ("content-encoding", "gzip") `elem` hs'
                    then joinI $ Z.ungzip x
                    else returnI x
        -- HEAD responses and 204/304 responses never carry a body
        -- (RFC 7230 section 3.3.3), whatever their headers say. Feeding
        -- EOF straight away (takeLBS 0) stops bodyStep from reading on into
        -- the connection and blocking until a keep-alive server hangs up.
        let hasNoBody = method == "HEAD" || sc == 204 || sc == 304
        if hasNoBody
            then joinI $ takeLBS 0 $$ bodyStep s hs'
            else body' $ decompress $$ bodyStep s hs'
-- | Decode a @Transfer-Encoding: chunked@ body, passing chunk payloads on
-- to the inner step.
--
-- Each round works as follows:
--
-- * 'iterChunkHeader' reads a chunk-size header.
-- * 'takeLBS' forwards that many bytes to the inner step.
-- * 'iterNewline' consumes the newline that ends the chunk.
--
-- A zero-size chunk ends the body. Parser failures are rethrown as
-- 'HttpParserException' by 'catchParser'. Once the inner step stops
-- accepting input, it is returned without reading further chunks.
--
-- NOTE(review): nothing after the zero-size chunk header (a final newline
-- or trailers) is consumed here; whether that matters depends on what
-- 'iterChunkHeader' consumes -- confirm.
chunkedEnumeratee :: MonadIO m => Enumeratee S.ByteString S.ByteString m a
chunkedEnumeratee step@(Continue _) = do
    chunkLen <- catchParser "Chunk header" iterChunkHeader
    case chunkLen of
        0 -> return step
        _ -> do
            step' <- takeLBS chunkLen step
            catchParser "End of chunk newline" iterNewline
            chunkedEnumeratee step'
chunkedEnumeratee done = return done
-- | Pass exactly @len@ bytes of the input stream through to the inner step,
-- then yield that step and push any surplus bytes back onto the stream.
--
-- If the input ends before @len@ bytes arrive, the inner step is returned
-- as-is. If the inner step stops accepting input before @len@ bytes have
-- been passed, it is returned at once and the remaining bytes of this
-- span are left unread.
takeLBS :: MonadIO m => Int -> Enumeratee S.ByteString S.ByteString m a
takeLBS 0 step = return step
takeLBS len (Continue k) = do
    mbs <- EL.head
    case mbs of
        Nothing -> return $ Continue k
        Just bs -> do
            -- Split the incoming chunk at the remaining byte budget: 'chunk'
            -- goes to the inner step, 'rest' belongs to whatever follows.
            -- A chunk of exactly 'len' bytes takes the second branch and
            -- leaves len' == 0 with no leftover.
            let (len', chunk, rest) =
                    if S.length bs > len
                        then (0, S.take len bs, Chunks [S.drop len bs])
                        else (len - S.length bs, bs, Chunks [])
            step' <- lift $ runIteratee $ k $ Chunks [chunk]
            if len' == 0
                then yield step' rest
                else takeLBS len' step'
takeLBS _ step = return step
-- | Percent-encode one character of a URL path.
--
-- Slashes are kept as path separators. A space becomes @%20@: the
-- @+@-for-space rule applied by 'encodeUrlChar' belongs to form-encoded
-- query strings, and inside a path a @+@ means a literal plus sign
-- (RFC 3986). Everything else is delegated to 'encodeUrlChar'.
encodeUrlCharPI :: Char -> String
encodeUrlCharPI '/' = "/"
encodeUrlCharPI ' ' = "%20"
encodeUrlCharPI c = encodeUrlChar c
-- | Percent-encode a single character for use in a URL.
--
-- * ASCII letters, digits and @-_.~@ pass through unchanged.
-- * A space becomes @+@.
-- * Anything else becomes @%XX@, with two upper-case hex digits taken from
--   the character's code modulo 256.
--
-- 'parseUrl1' UTF-8 encodes URLs before they reach the path encoder, so on
-- that route every character here is a single byte.
encodeUrlChar :: Char -> String
encodeUrlChar c
    | isUnreserved = [c]
    | c == ' ' = "+"
    | otherwise = ['%', hexDigit (code `div` 16 `mod` 16), hexDigit (code `mod` 16)]
  where
    code = fromEnum c
    isUnreserved =
        ('A' <= c && c <= 'Z')
            || ('a' <= c && c <= 'z')
            || ('0' <= c && c <= '9')
            || c == '-' || c == '_' || c == '.' || c == '~'
    hexDigit x
        | x < 10 = toEnum (x + fromEnum '0')
        | x < 16 = toEnum (x - 10 + fromEnum 'A')
        | otherwise = error $ "Invalid argument to showHex: " ++ show x
-- | Convert an absolute @http://@ or @https://@ URL into a GET 'Request',
-- parsing the query string into 'queryString'. Problems are reported
-- through 'failure' as 'InvalidUrlException'.
parseUrl :: Failure HttpException m => String -> m (Request m')
parseUrl url = parseUrlHelper True url
-- | Like 'parseUrl', except that the query string is not parsed: it stays
-- in 'path', and 'queryString' is left empty.
semiParseUrl :: Failure HttpException m => String -> m (Request m')
semiParseUrl url = parseUrlHelper False url
-- | Dispatch on the URL scheme: @http://@ and @https://@ continue to
-- 'parseUrl1' with the matching security flag, while anything else fails
-- with 'InvalidUrlException'. The scheme match is case-sensitive.
parseUrlHelper :: Failure HttpException m => Bool -> String -> m (Request m')
parseUrlHelper parsePath url =
    case url of
        'h':'t':'t':'p':':':'/':'/':rest -> parseUrl1 url False parsePath rest
        'h':'t':'t':'p':'s':':':'/':'/':rest -> parseUrl1 url True parsePath rest
        _ -> failure $ InvalidUrlException url "Invalid scheme"
-- | UTF-8 encode the post-scheme part of a URL (so each 'Char' is one
-- byte) and hand it to 'parseUrl2'.
parseUrl1 :: Failure HttpException m
          => String -> Bool -> Bool -> String -> m (Request m')
parseUrl1 full sec parsePath = parseUrl2 full sec parsePath . encodeString
-- | Build a GET 'Request' from the part of an absolute URL that follows the
-- scheme (host, optional @:port@, path, query, fragment).
--
-- * @full@ is the complete original URL, used only in error messages.
-- * @sec@ selects HTTPS and picks the default port (443, otherwise 80).
-- * @parsePath@: when 'True' the query string is split off and parsed into
--   'queryString'. When 'False' it stays in 'path', with its @?@, @&@ and
--   @=@ delimiters left unescaped.
--
-- Any @#fragment@ is dropped, since fragments are never sent to the
-- server. A non-numeric port fails with 'InvalidUrlException'.
parseUrl2 :: Failure HttpException m
          => String -> Bool -> Bool -> String -> m (Request m')
parseUrl2 full sec parsePath s = do
    port' <- mport
    return Request
        { host = S8.pack hostname
        , port = port'
        , secure = sec
        , checkCerts = \x ->
            if certificateVerifyDomain hostname x
                then certificateVerifyChain x
                else return False
        , requestHeaders = []
        , path = S8.pack
                    $ if null path''
                        then "/"
                        else concatMap encodePathChar path''
        , queryString = if parsePath
                            then W.parseQuery $ S8.pack qstring
                            else []
        , requestBody = RequestBodyLBS L.empty
        , method = "GET"
        }
  where
    -- Strip the fragment before splitting, so that it ends up in neither
    -- the host/port nor the path.
    withoutFragment = takeWhile (/= '#') s
    (beforeSlash, afterSlash) = break (== '/') withoutFragment
    (hostname, portStr) = break (== ':') beforeSlash
    (path', qstring') = break (== '?') afterSlash
    path'' = if parsePath then path' else afterSlash
    qstring = case qstring' of
                '?':x -> x
                _ -> qstring'
    -- In semi-parse mode the query string is part of path''. Escaping its
    -- delimiters (e.g. '?' -> %3F) would turn it into part of the path.
    encodePathChar c
        | not parsePath && (c == '?' || c == '&' || c == '=') = [c]
        | otherwise = encodeUrlCharPI c
    mport =
        case (portStr, sec) of
            ("", False) -> return 80
            ("", True) -> return 443
            (':':rest, _) ->
                case readMay rest of
                    Just i -> return i
                    Nothing -> failure $ InvalidUrlException full "Invalid port"
            x -> error $ "parseUrl1: this should never happen: " ++ show x
-- | Body iteratee that reads the whole remaining input into a lazy
-- ByteString and packages it with the status code and headers.
lbsIter :: Monad m => W.Status -> W.ResponseHeaders
        -> Iteratee S.ByteString m Response
lbsIter (W.Status code _) headers = do
    chunks <- EL.consume
    return Response
        { statusCode = code
        , responseHeaders = headers
        , responseBody = L.fromChunks chunks
        }
-- | Perform a request and read the whole response into memory with
-- 'lbsIter'. Redirects are not followed.
httpLbs :: MonadIO m => Request m -> Manager -> m Response
httpLbs req manager = run_ (http req lbsIter manager)
-- | Fetch a URL with a temporary 'Manager', following redirects, and return
-- the body.
--
-- The URL is parsed in semi-parse mode, so its query string is kept in the
-- path. A status outside 200-299 raises 'StatusCodeException' (with the
-- body) through 'failure'.
simpleHttp :: (MonadIO m, Failure HttpException m) => String -> m L.ByteString
simpleHttp url = do
    req <- parseUrlHelper False url
    res <- liftIO $ withManager $ httpLbsRedirect req
    let code = statusCode res
    if code >= 200 && code < 300
        then return (responseBody res)
        else failure $ StatusCodeException code (responseBody res)
-- | Errors raised by this module, either through 'failure' or as iteratee
-- errors.
data HttpException
    = StatusCodeException Int L.ByteString
      -- ^ status code and body of a non-2xx response (from 'simpleHttp')
    | InvalidUrlException String String
      -- ^ the offending URL and a short description of the problem
    | TooManyRedirects
      -- ^ the redirect budget was exhausted
    | HttpParserException String
      -- ^ the response could not be parsed; names the failing parse step
    deriving (Show, Typeable)

instance Exception HttpException
-- | Like 'http', but follows redirects (at most 10 further hops) via
-- 'redirectIter'.
httpRedirect
    :: (MonadIO m, Failure HttpException m)
    => Request m
    -> (W.Status -> W.ResponseHeaders -> Iteratee S.ByteString m a)
    -> Manager
    -> Iteratee S.ByteString m a
httpRedirect req bodyStep manager = http req followRedirects manager
  where
    followRedirects = redirectIter 10 req bodyStep manager
-- | Wrap a body iteratee so that 3xx responses with a @Location@ header are
-- followed.
--
-- * Location handling: a Location starting with @/@ is resolved against the
--   current request's scheme, host and port; anything else goes through
--   'parseUrl'. Only host, port, scheme, path and query string come from
--   the new URL; headers, 'checkCerts' and the rest are carried over.
-- * Method and body: 302 and 303 switch to GET and drop the request body
--   (RFC 7231 prescribes this for 303; nearly all clients do it for 302
--   too). Other codes resend the original method and body.
-- * Before the new request is made, the redirect's own body is read off
--   (except for HEAD), so its connection goes back to the pool clean.
-- * Budget: @redirects@ counts the further hops allowed; a redirect
--   arriving when it is 0 raises 'TooManyRedirects' through 'failure'.
--
-- Non-3xx responses, and 3xx responses without Location, go straight to
-- @bodyStep@.
--
-- NOTE(review): 'checkCerts' is not updated when the redirect changes
-- host, so certificates are still checked with the original request's
-- function -- confirm this is intended.
redirectIter :: (MonadIO m, Failure HttpException m)
             => Int
             -> Request m
             -> (W.Status -> W.ResponseHeaders -> Iteratee S.ByteString m a)
             -> Manager
             -> (W.Status -> W.ResponseHeaders -> Iteratee S.ByteString m a)
redirectIter redirects req bodyStep manager s@(W.Status code _) hs
    | 300 <= code && code < 400 =
        case lookup "location" hs of
            Just l'' -> do
                let l' =
                        case S8.uncons l'' of
                            Just ('/', _) -> concat
                                [ "http"
                                , if secure req then "s" else ""
                                , "://"
                                , S8.unpack $ host req
                                , ":"
                                , show $ port req
                                , S8.unpack l''
                                ]
                            _ -> S8.unpack l''
                l <- lift $ parseUrl l'
                -- 'method l' is always GET (parseUrl's default), so the
                -- decision has to be made against the original request.
                let switchToGet = code == 302 || code == 303
                    req' = req
                        { host = host l
                        , port = port l
                        , secure = secure l
                        , path = path l
                        , queryString = queryString l
                        , method =
                            if switchToGet
                                then "GET"
                                else method req
                        , requestBody =
                            if switchToGet
                                then RequestBodyLBS L.empty
                                else requestBody req
                        }
                if redirects == 0
                    then lift $ failure TooManyRedirects
                    else do
                        -- The nested request does not read from this
                        -- connection, so drain the redirect body here;
                        -- otherwise it stays unread when the connection is
                        -- pooled again. HEAD responses carry no body.
                        if method req == "HEAD"
                            then return ()
                            else EL.consume >> return ()
                        http req' (redirectIter (redirects - 1) req' bodyStep manager) manager
            Nothing -> bodyStep s hs
    | otherwise = bodyStep s hs
-- | Perform a request, following redirects, and read the final response
-- into memory with 'lbsIter'.
httpLbsRedirect :: (MonadIO m, Failure HttpException m) => Request m -> Manager -> m Response
httpLbsRedirect req manager = run_ (httpRedirect req lbsIter manager)
-- | Parse a value with its 'Read' instance, succeeding only when nothing but
-- whitespace follows it.
--
-- Returns the first complete parse, or 'Nothing'. Trailing garbage is
-- rejected, so @"80x"@ gives 'Nothing' rather than @Just 80@; this matters
-- for port numbers and Content-Length values.
readMay :: Read a => String -> Maybe a
readMay s =
    case [x | (x, rest) <- reads s, ("", "") <- lex rest] of
        x:_ -> Just x
        [] -> Nothing
-- | Turn a request into a form POST.
--
-- The key/value pairs are rendered as an
-- @application/x-www-form-urlencoded@ body, the method is set to POST, and
-- any existing Content-Type header is replaced.
urlEncodedBody :: Monad m => [(S.ByteString, S.ByteString)] -> Request m' -> Request m
urlEncodedBody params req = req
    { requestBody = RequestBodyLBS encoded
    , method = "POST"
    , requestHeaders =
        (contentType, "application/x-www-form-urlencoded") : otherHeaders
    }
  where
    contentType = "Content-Type"
    otherHeaders = [hdr | hdr@(name, _) <- requestHeaders req, name /= contentType]
    encoded = L.fromChunks [W.renderSimpleQuery False params]
-- | Run a parser iteratee, replacing any error it raises with an
-- 'HttpParserException' carrying @label@.
catchParser :: Monad m => String -> Iteratee a m b -> Iteratee a m b
catchParser label parser =
    parser `catchError` \_ -> throwError (HttpParserException label)
-- | Pool key: host, port, and whether the connection uses TLS.
type ConnKey = (String, Int, Bool)

-- | Pool of idle connections, holding at most one per 'ConnKey'.
data Manager = Manager
    { mConns :: I.IORef (Map ConnKey TLS.ConnInfo)
      -- ^ idle connections; removed while in use, put back afterwards
    }
-- | Atomically remove and return the pooled connection for @key@, if any.
takeInsecureSocket :: Manager -> ConnKey -> IO (Maybe TLS.ConnInfo)
takeInsecureSocket man key = I.atomicModifyIORef (mConns man) removeEntry
  where
    removeEntry conns =
        let found = Map.lookup key conns
            remaining = Map.delete key conns
        in (remaining, found)
-- | Atomically store @ci@ as the pooled connection for @key@. A connection
-- already stored under that key is displaced and closed.
putInsecureSocket :: Manager -> ConnKey -> TLS.ConnInfo -> IO ()
putInsecureSocket man key ci = do
    previous <- I.atomicModifyIORef (mConns man) $ \conns ->
        (Map.insert key ci conns, Map.lookup key conns)
    case previous of
        Just old -> TLS.connClose old
        Nothing -> return ()
-- | Create a manager with an empty connection pool.
newManager :: IO Manager
newManager = do
    pool <- I.newIORef Map.empty
    return (Manager pool)
-- | Empty the pool atomically, then close every connection it held.
closeManager :: Manager -> IO ()
closeManager man = do
    conns <- I.atomicModifyIORef (mConns man) (\old -> (Map.empty, old))
    mapM_ TLS.connClose (Map.elems conns)
-- | Run an action with a fresh 'Manager'. The manager is closed by
-- 'bracket' when the action finishes, including on exceptions.
withManager :: MonadControlIO m => (Manager -> m a) -> m a
withManager action = liftIOOp (bracket newManager closeManager) action