{-# LANGUAGE OverloadedStrings #-} -- | This module contains many helper functions, as well the code for 'Source', -- which is a pretty important structure module Network.MiniHTTP.HTTPConnection ( -- * Sources, and related functions Source , SourceResult(..) , bsSource , hSource , nullSource , sourceToLBS , sourceToBS , connSource , connChunkedSource , connEOFSource , sourceDrain , streamSource , streamSourceChunked -- * Misc functions , readIG , sourceIG , maybeRead , sslToBaseConnection ) where import Control.Concurrent.STM import qualified Control.Exception as Exception import Control.Monad (liftM) import qualified Data.ByteString as B import Data.ByteString.Char8 () import Data.ByteString.Internal (c2w, w2c) import qualified Data.ByteString.Lazy.Internal as BL import qualified Data.Binary.Strict.IncrementalGet as IG import Data.IORef import Data.Int (Int64) import System.IO import Text.Printf (printf) import System.IO.Unsafe (unsafeInterleaveIO) import qualified Network.Connection as C import Network.Socket as Socket import Network.MiniHTTP.Marshal (parseChunkHeader) import qualified OpenSSL.Session as SSL -- | A source is a stream of data, like a lazy data structure, but without -- some of the dangers that such entail. A source returns a 'SourceResult' -- each time you evaluate it. 
type Source = IO SourceResult

-- | The outcome of pulling once from a 'Source'.
data SourceResult = SourceError  -- ^ error - please don't read this source again
                  | SourceEOF  -- ^ end of data
                  | SourceData B.ByteString  -- ^ some data
                  deriving (Show)

-- | Construct a source from a ByteString. The resulting source yields the
-- whole string on the first pull and 'SourceEOF' on every later pull.
bsSource :: B.ByteString -> IO Source
bsSource bs = do
  ref <- newIORef $ SourceData bs
  return $ do
    v <- readIORef ref
    -- whatever was stored, the next pull sees end-of-data
    writeIORef ref SourceEOF
    return v

-- | Construct a source from a Handle. The handle is seeked to the first byte
-- when the source is constructed (not on the first pull). Each pull returns
-- at most 128KB. Once the last requested byte has been returned the source
-- yields 'SourceEOF'; running out of data before that point, or any
-- exception raised while reading, yields 'SourceError'.
hSource :: (Int64, Int64)  -- ^ the first and last byte to include
        -> Handle  -- ^ the handle to read from
        -> IO Source
hSource (from, to) handle = do
  -- absolute offset of the next byte that has not been returned yet
  offsetRef <- newIORef (from :: Int64)
  hSeek handle AbsoluteSeek (fromIntegral from)
  -- NOTE(review): the untyped handler is written against the old,
  -- non-polymorphic Control.Exception API; with base >= 4 it needs an
  -- explicit exception type. Confirm the targeted base version before
  -- changing it.
  return $ Exception.catch (readBlock offsetRef) (const $ return SourceError)
  where
    readBlock offsetRef = do
      done <- readIORef offsetRef
      -- Clamp while still in Int64 and only then narrow to Int: narrowing
      -- first can wrap on platforms where Int is 32 bits wide when more
      -- than 2GB remain.
      bytes <- B.hGet handle $ fromIntegral $ min (128 * 1024) ((to + 1) - done)
      if B.null bytes
        then return $ if to + 1 == done then SourceEOF else SourceError
        else do
          -- written strictly so no chain of (+) thunks builds up in the ref
          writeIORef offsetRef $! done + fromIntegral (B.length bytes)
          return $ SourceData bytes

-- | A source with no data (e.g. @/dev/null@)
nullSource :: Source
nullSource = return SourceEOF

-- | A source which reads from the given 'Connection' until the connection
-- signals end-of-file.
connEOFSource :: C.Connection -> IO Source
connEOFSource conn =
  -- Each pull asks the connection for up to 1024 bytes. Anything caught by
  -- 'catch' is reported as 'SourceEOF', never as 'SourceError'.
  -- NOTE(review): Control.Exception is imported qualified, so this
  -- unqualified 'catch' is presumably the Prelude one of older base
  -- releases; confirm before porting to a newer base.
  return $ catch (liftM SourceData $ C.read conn 1024)
                 (const $ return SourceEOF)

-- | A source which reads from a 'C.Connection'. The first pull returns the
-- prepended string; later pulls read from the connection in blocks of at
-- most 32KB until the byte count (which includes the prepended string) has
-- been reached. An empty read from the connection before that point is
-- reported as 'SourceError'.
connSource :: Int64  -- ^ the number of bytes to read
           -> B.ByteString  -- ^ a string which is prepended to the output
           -> C.Connection  -- ^ the connection to read from
           -> IO Source
connSource n bs conn
  -- the prepended string already is the whole body: no need to read at all
  | fromIntegral (B.length bs) == n = bsSource bs
  | otherwise = do
      -- (has the prepended string been returned?, bytes returned so far)
      ref <- newIORef (False, 0 :: Int64)
      return $ do
        (doneBS, n') <- readIORef ref
        if not doneBS
          then do
            writeIORef ref (True, fromIntegral $ B.length bs)
            return $ SourceData bs
          else if n' == n
            then return SourceEOF
            else do
              -- Clamp in Int64 before narrowing to Int (as
              -- connChunkedSource does) so a large remainder cannot wrap
              -- where Int is 32 bits wide.
              bytes <- C.read conn $ fromIntegral $ min (32 * 1024) (n - n')
              if B.null bytes
                then return SourceError
                else do
                  writeIORef ref (doneBS, n' + fromIntegral (B.length bytes))
                  return $ SourceData bytes

-- | A source which reads an HTTP chunked reply from a 'C.Connection'
connChunkedSource :: C.Connection -> IO Source
connChunkedSource conn = do
  -- the contents of this reference are the number of bytes remaining in the
  -- current chunk. If zero, a chunk header needs to be read. If -1, we have
  -- hit EOF. If we read the end of a chunk, we always read the trailing \r\n
  -- before returning (so one need never consider that case on entry)
  ref <- newIORef (0 :: Int64)
  let f = do
        remainingInThisChunk <- readIORef ref
        case remainingInThisChunk of
          0 -> do
            -- block size 16, give up after 256 bytes of header
            m <- readIG conn 16 256 parseChunkHeader
            case m of
              Nothing -> return SourceError
              Just n ->
                if n == 0
                  -- last chunk: consume the two bytes following it and
                  -- remember that we are finished
                  then C.reada conn 2 >> writeIORef ref (-1) >> return SourceEOF
                  -- a fresh chunk: record its size and go read its data
                  else writeIORef ref n >> f
          (-1) -> return SourceEOF
          remaining -> do
            bytes <- C.read conn $ fromIntegral $ min remaining $ 32 * 1024
            if B.null bytes
              then return SourceError
              else do
                let stillRemaining = remaining - fromIntegral (B.length bytes)
                if stillRemaining == 0
                  then do
                    C.reada conn 2  -- read \r\n
                    writeIORef ref 0
                  else writeIORef ref stillRemaining
                return $ SourceData bytes
  return f

-- | Read a source until it returns 'SourceEOF'. Draining also stops,
-- silently, when the source reports 'SourceError'.
sourceDrain :: Source -> IO ()
sourceDrain s = do
  v <- s
  case v of
    SourceEOF -> return ()
    SourceError -> return ()
    SourceData _ -> sourceDrain s

-- | Convert a source to a lazy ByteString. The first pull happens
-- immediately; every later pull is deferred with 'unsafeInterleaveIO' until
-- the corresponding part of the result is demanded. A 'SourceError' is
-- turned into a call to 'fail' at the point where that pull runs.
sourceToLBS :: Source -> IO BL.ByteString
sourceToLBS s = do
  bytes <- s
  case bytes of
    SourceEOF -> return BL.Empty
    SourceError -> fail "Error in reading from client"
    SourceData bs -> do
      rest <- unsafeInterleaveIO $ sourceToLBS s
      return $ BL.Chunk bs rest

-- | Take, at most, the first n bytes from a Source and return a strict
-- ByteString. Returns Nothing on error. (A short read is not an error)
sourceToBS :: Int -> Source -> IO (Maybe B.ByteString)
sourceToBS n source = f 0
  where
    -- soFar is the number of bytes collected by the enclosing calls
    f soFar = do
      s <- source
      case s of
        SourceEOF -> return $ Just B.empty
        SourceError -> return Nothing
        SourceData bs
          -- this block reaches the limit: keep only what is still wanted
          | B.length bs + soFar >= n -> return $ Just $ B.take (n - soFar) bs
          | otherwise -> do
              rest <- f (soFar + B.length bs)
              return $ fmap (B.append bs) rest

-- | Stream a source to a connection while not enqueuing more than lowWater
-- bytes in the outbound queue (not inc the kernel buffer). Returns True if
-- the source ended with 'SourceEOF' and False if it reported 'SourceError'.
streamSource :: Int -> C.Connection -> Source -> IO Bool
streamSource lowWater conn source = do
  next <- source
  case next of
    SourceEOF -> return True
    SourceError -> return False
    SourceData bs -> do
      atomically $ C.writeAtLowWater lowWater conn bs
      streamSource lowWater conn source

-- | Stream a source to a connection, with chunked encoding, while not
-- enqueuing more than lowWater bytes in the outbound queue (not inc the
-- kernel buffer). Returns True once the terminating chunk has been written
-- and False if the source reported 'SourceError' (in which case no
-- terminating chunk is written).
streamSourceChunked :: Int -> C.Connection -> Source -> IO Bool
streamSourceChunked lowWater conn source = do
  next <- source
  case next of
    SourceEOF -> do
      -- last-chunk, no trailer fields, final CRLF
      atomically $ C.writeAtLowWater lowWater conn "0\r\n\r\n"
      return True
    SourceError -> return False
    SourceData bs
      -- A zero-length chunk is how the end of the body is spelled, so an
      -- empty block must not be written as a chunk; just skip it.
      | B.null bs -> streamSourceChunked lowWater conn source
      | otherwise -> do
          -- The chunk size is written in hexadecimal and followed by
          -- exactly one CRLF; the chunk data starts right after it.
          atomically $ C.writeAtLowWater lowWater conn $
            B.pack $ map c2w $ printf "%x\r\n" $ B.length bs
          atomically $ C.writeAtLowWater lowWater conn bs
          atomically $ C.writeAtLowWater lowWater conn "\r\n"
          streamSourceChunked lowWater conn source

-- | Run an incremental parser from the network. Returns Nothing if the
-- parser fails or if the byte limit is reached; the limit is checked before
-- the parser state is inspected, and the very first block read is not
-- counted towards it. Input left over after a successful parse is pushed
-- back onto the connection.
readIG :: C.Connection  -- ^ the connection to read from
       -> Int  -- ^ the block size to use
       -> Int  -- ^ maximum number of bytes to parse
       -> IG.Get a a  -- ^ the parser
       -> IO (Maybe a)
readIG conn blockSize maxBytes parser = do
  let f sofar result
        | sofar >= maxBytes = return Nothing
        | otherwise =
            case result of
              IG.Failed _ -> return Nothing
              IG.Partial cont -> do
                bs <- C.read conn blockSize
                f (sofar + B.length bs) $ cont bs
              IG.Finished rest value -> do
                atomically $ C.pushBack conn rest
                return $ Just value
  firstBlock <- C.read conn blockSize
  f 0 $ IG.runGet parser firstBlock

-- | Run an incremental parser from a 'Source'. Returns Nothing if the
-- parser fails, if the byte limit is reached, or if the source ends (with
-- 'SourceEOF' or 'SourceError') while the parser still wants input. Input
-- left over after a successful parse is discarded.
sourceIG :: Source  -- ^ the source to read from
         -> Int  -- ^ the maximum number of bytes to parse
         -> IG.Get a a  -- ^ the parser
         -> IO (Maybe a)
sourceIG source maxBytes parser = do
  let f sofar result
        | sofar >= maxBytes = return Nothing
        | otherwise =
            case result of
              IG.Failed _ -> return Nothing
              IG.Partial cont -> do
                s <- source
                case s of
                  SourceError -> return Nothing
                  SourceEOF -> return Nothing
                  SourceData bytes -> f (B.length bytes + sofar) $ cont bytes
              IG.Finished _ value -> return $ Just value
  -- the parser is started on empty input so that all data comes from pulls
  f 0 $ IG.runGet parser B.empty

-- | Parse a value with its 'Read' instance. Succeeds only when there is
-- exactly one parse and it consumes the entire string.
maybeRead :: Read a => B.ByteString -> Maybe a
maybeRead s =
  case reads $ map w2c $ B.unpack s of
    [(x, "")] -> Just x
    _ -> Nothing

-- | Convert an SSL connection to a BaseConnection for Network.Connection
sslToBaseConnection :: SSL.SSL -> C.BaseConnection
sslToBaseConnection ssl = C.BaseConnection r w c
  where
    -- read: hand the request straight to the SSL layer
    r n = SSL.read ssl n
    -- write: the whole string is always reported as written
    w bs = SSL.write ssl bs >> return (B.length bs)
    -- close: shut the SSL session down, then close the underlying socket
    c = SSL.shutdown ssl SSL.Unidirectional >> sClose (SSL.sslSocket ssl)