{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
-- | Some helpers for parsing data out of a raw WAI 'Request'.
--
-- copy from wai-extra 3.0.4.5.
module Network.Wai.Parse
    ( parseHttpAccept
    , parseRequestBody
    , RequestBodyType (..)
    , getRequestBodyType
    , sinkRequestBody
    , BackEnd
    , lbsBackEnd
    , Param
    , File
    , FileInfo (..)
    , parseContentType
#if TEST
    , Bound (..)
    , findBound
    , sinkTillBound
    , killCR
    , killCRLF
    , takeLine
#endif
    ) where

import qualified Data.ByteString.Search as Search
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Char8 as S8
import Data.Word (Word8)
import Data.Maybe (fromMaybe)
import Data.List (sortBy)
import Data.Function (on)
import Data.Char (toLower)
import Network.Wai
import qualified Network.HTTP.Types as H
import Control.Monad (when, unless)
import Data.IORef

-- | Split a 'S.ByteString' at the first occurrence of the given byte and drop
-- that byte from the second half. If the byte does not occur, the second half
-- is empty.
breakDiscard :: Word8 -> S.ByteString -> (S.ByteString, S.ByteString)
breakDiscard w s =
    let (x, y) = S.break (== w) s
     in (x, S.drop 1 y)

-- | Remove one pair of surrounding double quotes from a value.
--
-- The value is only unquoted when it is longer than two bytes, so a bare
-- pair of quotes (an empty quoted string) is returned unchanged.
dequote :: S.ByteString -> S.ByteString
dequote s
    | S.length s > 2 && S.head s == 34 && S.last s == 34 = S.tail $ S.init s -- 34 is double quote
    | otherwise = s

-- | Parse the HTTP accept string to determine supported content types.
--
-- Entries are returned ordered by descending @q@ value (defaulting to 1.0
-- when absent or unreadable), ties being broken in favour of the more
-- specific entry (more semicolons, fewer stars).
parseHttpAccept :: S.ByteString -> [S.ByteString]
parseHttpAccept = map fst
                . sortBy (rcompare `on` snd)
                . map (addSpecificity . grabQ)
                . S.split 44 -- comma
  where
    rcompare :: (Double,Int) -> (Double,Int) -> Ordering
    rcompare = flip compare
    addSpecificity (s, q) =
        -- Prefer higher-specificity types
        let semicolons = S.count 0x3B s
            stars = S.count 0x2A s
         in (s, (q, semicolons - stars))
    grabQ s =
        -- Stripping all spaces may be too harsh.
        -- Maybe just strip either side of semicolon?
        let (s', q) = S.breakSubstring ";q=" (S.filter (/=0x20) s) -- 0x20 is space
            q' = S.takeWhile (/=0x3B) (S.drop 3 q) -- 0x3B is semicolon
         in (s', readQ q')
    readQ s = case reads $ S8.unpack s of
                (x, _):_ -> x
                _ -> 1.0

-- | Store uploaded files in memory.
--
-- Repeatedly runs the supplied action until it yields an empty chunk and
-- returns all chunks, in order, as one lazy 'L.ByteString'. The first two
-- arguments are ignored.
lbsBackEnd :: Monad m => ignored1 -> ignored2 -> m S.ByteString -> m L.ByteString
lbsBackEnd _ _ popper =
    loop id
  where
    -- @front@ is a difference list of the chunks read so far.
    loop front = do
        bs <- popper
        if S.null bs
            then return $ L.fromChunks $ front []
            else loop $ front . (bs:)

-- | Information on an uploaded file.
data FileInfo c = FileInfo
    { fileName :: S.ByteString
    , fileContentType :: S.ByteString
    , fileContent :: c
    }
    deriving (Eq, Show)

-- | Post parameter name and value.
type Param = (S.ByteString, S.ByteString)

-- | Post parameter name and associated file information.
type File y = (S.ByteString, FileInfo y)

-- | A file uploading backend. Takes the parameter name, file name, and a
-- stream of data.
type BackEnd a = S.ByteString -- ^ parameter name
              -> FileInfo ()
              -> IO S.ByteString
              -> IO a

-- | The request body encodings this module can parse. 'Multipart' carries the
-- value of the @boundary@ attribute of the content type.
data RequestBodyType = UrlEncoded | Multipart S.ByteString

-- | Determine the body type of a request from its @Content-Type@ header.
--
-- Returns 'Nothing' when the header is missing, names another content type,
-- or is @multipart/form-data@ without a @boundary@ attribute.
getRequestBodyType :: Request -> Maybe RequestBodyType
getRequestBodyType req = do
    ctype' <- lookup "Content-Type" $ requestHeaders req
    let (ctype, attrs) = parseContentType ctype'
    case ctype of
        "application/x-www-form-urlencoded" -> return UrlEncoded
        "multipart/form-data" | Just bound <- lookup "boundary" attrs -> return $ Multipart bound
        _ -> Nothing

-- | Parse a content type value, turning a single @ByteString@ into the actual
-- content type and a list of pairs of attributes.
--
-- Attribute names and values have surrounding spaces removed, and a value
-- wrapped in double quotes (for example @boundary="xyz"@) is returned without
-- the quotes.
--
-- Since 1.3.2
parseContentType :: S.ByteString -> (S.ByteString, [(S.ByteString, S.ByteString)])
parseContentType a =
    let (ctype, b) = S.break (== semicolon) a
        attrs = goAttrs id $ S.drop 1 b
     in (ctype, attrs)
  where
    semicolon = 59
    equals = 61
    space = 32
    -- Walk the semicolon-separated attribute list, accumulating the parsed
    -- pairs in a difference list.
    goAttrs front bs
        | S.null bs = front []
        | otherwise =
            let (x, rest) = S.break (== semicolon) bs
             in goAttrs (front . (goAttr x:)) $ S.drop 1 rest
    goAttr bs =
        let (k, v') = S.break (== equals) bs
            v = S.drop 1 v'
         in (strip k, dequote $ strip v)
    strip = S.dropWhile (== space) . fst . S.breakEnd (/= space)

-- | Parse the body of a request into its parameters and uploaded files, using
-- the given backend to store file contents. Requests whose body type is not
-- recognised by 'getRequestBodyType' yield two empty lists.
parseRequestBody :: BackEnd y
                 -> Request
                 -> IO ([Param], [File y])
parseRequestBody s r =
    case getRequestBodyType r of
        Nothing -> return ([], [])
        Just rbt -> sinkRequestBody s rbt (requestBody r)

-- | Like 'parseRequestBody', but takes the body type and the chunk-producing
-- action explicitly. The action signals the end of the body by returning an
-- empty chunk. Parameters and files are returned in the order encountered.
sinkRequestBody :: BackEnd y
                -> RequestBodyType
                -> IO S.ByteString
                -> IO ([Param], [File y])
sinkRequestBody s r body = do
    -- Two difference lists: parameters and files.
    ref <- newIORef (id, id)
    let add x = atomicModifyIORef ref $ \(y, z) ->
            case x of
                Left y' -> ((y . (y':), z), ())
                Right z' -> ((y, z . (z':)), ())
    conduitRequestBody s r body add
    (x, y) <- readIORef ref
    return (x [], y [])

-- | Parse a request body of the given type, reporting every parameter
-- ('Left') and file ('Right') through the supplied callback.
conduitRequestBody :: BackEnd y
                   -> RequestBodyType
                   -> IO S.ByteString
                   -> (Either Param (File y) -> IO ())
                   -> IO ()
conduitRequestBody _ UrlEncoded rbody add = do
    -- NOTE: in general, url-encoded data will be in a single chunk.
    -- Therefore, I'm optimizing for the usual case by sticking with
    -- strict byte strings here.
    let loop front = do
            bs <- rbody
            if S.null bs
                then return $ S.concat $ front []
                else loop $ front . (bs:)
    bs <- loop id
    mapM_ (add . Left) $ H.parseSimpleQuery bs
conduitRequestBody backend (Multipart bound) rbody add =
    -- Delimiter lines consist of two dashes followed by the boundary.
    parsePieces backend (S8.pack "--" `S.append` bound) rbody add

-- | Read one line (terminated by LF) from the source.
--
-- Returns the line without its LF and without one trailing CR. Whatever
-- follows the LF in the current chunk is pushed back onto the source. If the
-- source ends before an LF is seen, the bytes read so far are pushed back and
-- 'Nothing' is returned.
takeLine :: Source -> IO (Maybe S.ByteString)
takeLine src =
    go id
  where
    -- @front@ prepends the partial line read from earlier chunks.
    go front = do
        bs <- readSource src
        if S.null bs
            then close front
            else push front bs

    close front = leftover src (front S.empty) >> return Nothing
    push front bs =
        let (x, y) = S.break (== 10) $ front bs -- LF
         in if S.null y
                then go $ S.append x
                else do
                    when (S.length y > 1) $ leftover src $ S.drop 1 y
                    return $ Just $ killCR x

-- | Read lines with 'takeLine' until an empty line or the end of the source,
-- returning the non-empty lines read.
takeLines :: Source -> IO [S.ByteString]
takeLines src = do
    res <- takeLine src
    case res of
        Nothing -> return []
        Just l
            | S.null l -> return []
            | otherwise -> do
                ls <- takeLines src
                return $ l : ls

-- | A chunk-producing action paired with a slot holding pushed-back input.
data Source = Source (IO S.ByteString) (IORef S.ByteString)

-- | Wrap a chunk-producing action in a 'Source' with no pushed-back input.
mkSource :: IO S.ByteString -> IO Source
mkSource f = do
    ref <- newIORef S.empty
    return $ Source f ref

-- | Return (and clear) the pushed-back input if there is any; otherwise run
-- the underlying action.
readSource :: Source -> IO S.ByteString
readSource (Source f ref) = do
    bs <- atomicModifyIORef ref $ \bs -> (S.empty, bs)
    if S.null bs
        then f
        else return bs

-- | Push bytes back onto the source, replacing anything already pushed back.
leftover :: Source -> S.ByteString -> IO ()
leftover (Source _ ref) bs = writeIORef ref bs

-- | Parse the parts of a multipart body.
--
-- For each part the delimiter line and the header lines are read. A part
-- whose @Content-Disposition@ has both @name@ and @filename@ attributes is
-- streamed to the backend and reported as a file; one with only @name@ is
-- collected in memory and reported as a parameter; any other part is
-- consumed and ignored. Header names are matched case-insensitively.
-- Parsing continues for as long as the boundary keeps being found.
parsePieces :: BackEnd y
            -> S.ByteString
            -> IO S.ByteString
            -> (Either Param (File y) -> IO ())
            -> IO ()
parsePieces sink bound rbody add =
    mkSource rbody >>= loop
  where
    loop src = do
        _boundLine <- takeLine src
        res' <- takeLines src
        unless (null res') $ do
            let ls' = map parsePair res'
            let x = do
                    cd <- lookup contDisp ls'
                    let ct = lookup contType ls'
                    let attrs = parseAttrs cd
                    name <- lookup "name" attrs
                    return (ct, name, lookup "filename" attrs)
            case x of
                Just (mct, name, Just filename) -> do
                    let ct = fromMaybe "application/octet-stream" mct
                        fi0 = FileInfo filename ct ()
                    (wasFound, y) <- sinkTillBound' bound name fi0 sink src
                    add $ Right (name, fi0 { fileContent = y })
                    when wasFound (loop src)
                Just (_ct, name, Nothing) -> do
                    let seed = id
                    let iter front bs = return $ front . (:) bs
                    (wasFound, front) <- sinkTillBound bound iter seed src
                    let bs = S.concat $ front []
                    let x' = (name, bs)
                    add $ Left x'
                    when wasFound (loop src)
                _ -> do
                    -- ignore this part
                    let seed = ()
                        iter () _ = return ()
                    (wasFound, ()) <- sinkTillBound bound iter seed src
                    when wasFound (loop src)
      where
        -- Lookup keys are lower-case because 'parsePair' lower-cases the
        -- header names it produces.
        contDisp = S8.pack "content-disposition"
        contType = S8.pack "content-type"
        parsePair s =
            let (x, y) = breakDiscard 58 s -- colon
             in (S8.map toLower x, S.dropWhile (== 32) y) -- space

-- | Result of searching a chunk for the boundary.
data Bound = FoundBound S.ByteString S.ByteString -- ^ bytes before and after the boundary
           | NoBound -- ^ the boundary does not occur, not even partially at the end
           | PartialBound -- ^ the chunk ends with a proper prefix of the boundary
    deriving (Eq, Show)

-- | Search for the boundary @b@ inside @bs@.
--
-- When the boundary is not found as a whole, the tail of @bs@ (at most
-- @length b@ bytes) is checked for a suffix that matches the beginning of
-- the boundary, in which case more input is needed to decide.
findBound :: S.ByteString -> S.ByteString -> Bound
findBound b bs = handleBreak $ Search.breakOn b bs
  where
    handleBreak (h, t)
        | S.null t = go [lowBound..S.length bs - 1]
        | otherwise = FoundBound h $ S.drop (S.length b) t

    lowBound = max 0 $ S.length bs - S.length b

    -- Try every candidate start index near the end of @bs@.
    go [] = NoBound
    go (i:is)
        | mismatch [0..S.length b - 1] [i..S.length bs - 1] = go is
        | otherwise =
            let endI = i + S.length b
             in if endI > S.length bs
                    then PartialBound
                    else FoundBound (S.take i bs) (S.drop endI bs)
    -- True as soon as a compared pair of bytes differs; running out of
    -- either index list counts as no mismatch.
    mismatch [] _ = False
    mismatch _ [] = False
    mismatch (x:xs) (y:ys)
        | S.index b x == S.index bs y = mismatch xs ys
        | otherwise = True

-- | Feed the bytes up to the next boundary to a file backend. Returns whether
-- the boundary was found together with the backend's result.
--
-- The backend has to drain the stream it is given; otherwise the final check
-- raises the error from 'wrapTillBound'.
sinkTillBound' :: S.ByteString
               -> S.ByteString
               -> FileInfo ()
               -> BackEnd y
               -> Source
               -> IO (Bool, y)
sinkTillBound' bound name fi sink src = do
    (next, final) <- wrapTillBound bound src
    y <- sink name fi next
    b <- final
    return (b, y)

-- | State of a 'wrapTillBound' stream: still working (with held-back bytes to
-- prepend to the next chunk) or done (with whether the boundary was found).
data WTB = WTBWorking (S.ByteString -> S.ByteString)
         | WTBDone Bool

-- | Build a chunk-producing action that yields the source's content up to
-- (not including) the boundary and the line break before it, followed by
-- empty chunks, plus an action reporting how the stream ended.
wrapTillBound :: S.ByteString -- ^ bound
              -> Source
              -> IO (IO S.ByteString, IO Bool) -- ^ Bool indicates if the bound was found
wrapTillBound bound src = do
    ref <- newIORef $ WTBWorking id
    return (go ref, final ref)
  where
    final ref = do
        x <- readIORef ref
        case x of
            WTBWorking _ -> error "wrapTillBound did not finish"
            WTBDone y -> return y

    go ref = do
        state <- readIORef ref
        case state of
            WTBDone _ -> return S.empty
            WTBWorking front -> do
                bs <- readSource src
                if S.null bs
                    then do
                        -- End of input: flush whatever was held back.
                        writeIORef ref $ WTBDone False
                        return $ front bs
                    else push $ front bs
      where
        push bs =
            case findBound bound bs of
                FoundBound before after -> do
                    let before' = killCRLF before
                    leftover src after
                    writeIORef ref $ WTBDone True
                    return before'
                NoBound -> do
                    -- don't emit newlines, in case it's part of a bound
                    let (toEmit, front') =
                            if not (S8.null bs) && S8.last bs `elem` ['\r','\n']
                                then let (x, y) = S.splitAt (S.length bs - 2) bs
                                      in (x, S.append y)
                                else (bs, id)
                    writeIORef ref $ WTBWorking front'
                    if S.null toEmit
                        then go ref
                        else return toEmit
                PartialBound -> do
                    -- Hold the whole chunk back until more input arrives.
                    writeIORef ref $ WTBWorking $ S.append bs
                    go ref

-- | Fold over the chunks up to the next boundary with the given step
-- function. Returns whether the boundary was found and the final
-- accumulator.
sinkTillBound :: S.ByteString
              -> (x -> S.ByteString -> IO x)
              -> x
              -> Source
              -> IO (Bool, x)
sinkTillBound bound iter seed0 src = do
    (next, final) <- wrapTillBound bound src
    let loop seed = do
            bs <- next
            if S.null bs
                then return seed
                else iter seed bs >>= loop
    seed <- loop seed0
    b <- final
    return (b, seed)

-- | Split a header value on semicolons into @(key, value)@ pairs.
--
-- Each piece is split at its first equals sign; leading spaces are dropped
-- from both halves and quoted values are unquoted with 'dequote'. A piece
-- without an equals sign yields an empty value.
parseAttrs :: S.ByteString -> [(S.ByteString, S.ByteString)]
parseAttrs = map go . S.split 59 -- semicolon
  where
    tw = S.dropWhile (== 32) -- space
    go s =
        let (x, y) = breakDiscard 61 s -- equals sign
         in (tw x, dequote $ tw y)

-- | Remove a trailing LF, and the CR before it if there is one. Input not
-- ending in LF is returned unchanged.
killCRLF :: S.ByteString -> S.ByteString
killCRLF bs
    | S.null bs || S.last bs /= 10 = bs -- line feed
    | otherwise = killCR $ S.init bs

-- | Remove a single trailing CR, if present.
killCR :: S.ByteString -> S.ByteString
killCR bs
    | S.null bs || S.last bs /= 13 = bs -- carriage return
    | otherwise = S.init bs