module Network.Wai.Handler.Warp.Conduit where

import Control.Applicative
import Control.Exception
import Control.Monad (unless)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy.Char8 (pack)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.Conduit.Internal (ResumableSource (..))
import qualified Data.Conduit.List as CL
import qualified Data.IORef as I
import Data.Word (Word, Word8)
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | Contains a @Source@ and a byte count that is still to be read in.
newtype IsolatedBSSource = IsolatedBSSource (I.IORef (Int, ResumableSource (ResourceT IO) ByteString))

-- | Given an @IsolatedBSSource@ provide a @Source@ that only allows up to the
-- specified number of bytes to be passed downstream. All leftovers should be
-- retained within the @Source@. If there are not enough bytes available,
-- throws a @ConnectionClosedByPeer@ exception.
--
-- The remaining count and the resumed source are written back to the
-- @IORef@ before every 'yield', so the stream can be picked up again later.
ibsIsolate :: IsolatedBSSource -> Source (ResourceT IO) ByteString
ibsIsolate ibs@(IsolatedBSSource ref) = do
    (count, src) <- liftIO $ I.readIORef ref
    unless (count == 0) $ do
        -- Get the next chunk (if available) and the updated source
        (src', mbs) <- lift $ src $$++ CL.head

        -- If no chunk available, then there aren't enough bytes in the
        -- stream. Throw a ConnectionClosedByPeer
        bs <- maybe (liftIO $ throwIO ConnectionClosedByPeer) return mbs

        let -- How many of the bytes in this chunk to send downstream
            toSend = min count (S.length bs)
            -- How many bytes will still remain to be sent downstream
            count' = count - toSend
        case () of
            ()
                -- The expected count is greater than the size of the
                -- chunk we just read. Send the entire chunk
                -- downstream, and then loop on this function for the
                -- next chunk.
                | count' > 0 -> do
                    liftIO $ I.writeIORef ref (count', src')
                    yield bs
                    ibsIsolate ibs

                -- The expected count is the total size of the chunk we
                -- just read. Send this chunk downstream, and then
                -- terminate the stream.
                | count == S.length bs -> do
                    liftIO $ I.writeIORef ref (count', src')
                    yield bs

                -- Some of the bytes in this chunk should not be sent
                -- downstream. Split up the chunk into the sent and
                -- not-sent parts, add the not-sent parts onto the new
                -- source, and send the rest of the chunk downstream.
                | otherwise -> do
                    let (x, y) = S.splitAt toSend bs
                    liftIO $ I.writeIORef ref (count', fmapResume (yield y >>) src')
                    yield x

-- | Extract the underlying @Source@ from an @IsolatedBSSource@, which will not
-- perform any more isolation.
ibsDone :: IsolatedBSSource -> IO (ResumableSource (ResourceT IO) ByteString)
ibsDone (IsolatedBSSource ref) = snd <$> I.readIORef ref

----------------------------------------------------------------

-- | Progress of 'chunkedSource' through a chunked-encoded body.
data ChunkState
    = NeedLen
      -- ^ The next thing to read is a chunk-size line.
    | NeedLenNewline
      -- ^ A chunk's data has been fully read; two bytes (the CRLF after
      -- the data) are skipped before the next chunk-size line is read.
    | HaveLen Word
      -- ^ This many bytes of the current chunk remain. @HaveLen 0@ means a
      -- zero size was parsed (or the stream ended while reading a size
      -- line) and the final CRLF still has to be dropped.
    | DoneChunking
      -- ^ The whole body has been consumed; running the source again reads
      -- nothing further from the underlying stream.

-- | The two-byte line terminator @\"\\r\\n\"@.
bsCRLF :: L.ByteString
bsCRLF = pack "\r\n"

-- | Decode a chunked body from the source stored in the @IORef@.
--
-- The resumed source and the current 'ChunkState' are written back to the
-- @IORef@ before each chunk is yielded and when decoding finishes. That
-- lets the body be resumed (or drained) later without losing or re-reading
-- bytes. Once 'DoneChunking' has been stored, re-running this source is a
-- no-op. Previously a finished body stayed in @HaveLen 0@, and re-running
-- it would read two more bytes from the connection.
chunkedSource :: MonadIO m
              => I.IORef (ResumableSource m ByteString, ChunkState)
              -> Source m ByteString
chunkedSource ipair = do
    (src, mlen) <- liftIO $ I.readIORef ipair
    go src mlen
  where
    -- Run @front@ (which may skip bytes first) followed by 'getLen'. Any
    -- bytes read past the size line are pushed back onto the source.
    go' src front = do
        (src', (len, bs)) <- lift $ src $$++ front getLen
        let src''
                | S.null bs = src'
                | otherwise = fmapResume (yield bs >>) src'
        go src'' $ HaveLen len

    go src NeedLen = go' src id
    go src NeedLenNewline = go' src (CB.take 2 >>)
    go src (HaveLen 0) = do
        -- Drop the final CRLF; anything else read here is pushed back.
        (src', ()) <- lift $ src $$++ do
            crlf <- CB.take 2
            unless (crlf == bsCRLF) $ leftover $ S.concat $ L.toChunks crlf
        -- Record completion so a later run does not read from the stream.
        liftIO $ I.writeIORef ipair (src', DoneChunking)
    go src (HaveLen len) = do
        (src', mbs) <- lift $ src $$++ CL.head
        case mbs of
            -- The stream ended in the middle of a chunk: nothing more to read.
            Nothing -> liftIO $ I.writeIORef ipair (src', DoneChunking)
            Just bs ->
                case S.length bs `compare` fromIntegral len of
                    EQ -> yield' src' NeedLenNewline bs
                    LT -> do
                        let mlen = HaveLen $ len - fromIntegral (S.length bs)
                        yield' src' mlen bs
                    GT -> do
                        -- The chunk read also holds bytes beyond this
                        -- chunk's data; push them back onto the source.
                        let (x, y) = S.splitAt (fromIntegral len) bs
                            src'' = fmapResume (yield y >>) src'
                        yield' src'' NeedLenNewline x
    go _ DoneChunking = return ()

    -- Save the state before yielding so that the body can be resumed.
    yield' src mlen bs = do
        liftIO $ I.writeIORef ipair (src, mlen)
        yield bs
        go src mlen

-- | Read a chunk-size line. Returns the hexadecimal size parsed from the
-- leading hex digits of the line, together with any bytes that were read
-- past the line's newline (to be pushed back by the caller). Returns
-- @(0, empty)@ when the stream is already exhausted.
getLen :: Monad m => Sink ByteString m (Word, ByteString)
getLen = do
    mbs <- CL.head
    case mbs of
        Nothing -> return (0, S.empty)
        Just bs -> do
            (x, y) <- readLine bs
            let w = S.foldl' (\i c -> i * 16 + fromIntegral (hexToWord c)) 0
                        (S.takeWhile isHexDigit x)
            -- Drop the newline itself; the rest belongs to the chunk data.
            return (w, S.drop 1 y)
  where
    -- Keep pulling chunks until the accumulated bytes contain a newline or
    -- the stream ends, so a size line split across several reads is parsed
    -- in full (previously at most two chunks were examined).
    -- NOTE(review): there is no cap on the accumulated line length.
    readLine acc =
        case S.break (== 10) acc of
            (x, y)
                | S.null y -> do
                    mnext <- CL.head
                    case mnext of
                        Nothing -> return (x, y)
                        Just next -> readLine (acc `S.append` next)
                | otherwise -> return (x, y)

-- | Numeric value of an ASCII hexadecimal digit (@0-9@, @A-F@, @a-f@).
-- Only meaningful for inputs accepted by 'isHexDigit'.
hexToWord :: (Ord a, Num a) => a -> a
hexToWord w
    | w < 58 = w - 48
    | w < 71 = w - 55
    | otherwise = w - 87

-- | Is the byte an ASCII hexadecimal digit (@0-9@, @A-F@, @a-f@)?
isHexDigit :: Word8 -> Bool
isHexDigit w = w >= 48 && w <= 57
            || w >= 65 && w <= 70
            || w >= 97 && w <= 102

----------------------------------------------------------------

-- | Transform the not-yet-consumed @Source@ part of a @ResumableSource@,
-- keeping its finalizer unchanged (used here to prepend leftover bytes).
fmapResume :: (Source m o1 -> Source m o2) -> ResumableSource m o1 -> ResumableSource m o2
fmapResume f (ResumableSource src m) = ResumableSource (f src) m