module Snap.Util.FileUploads
(
handleFileUploads
, handleMultipart
, PartInfo(..)
, UploadPolicy
, defaultUploadPolicy
, doProcessFormInputs
, setProcessFormInputs
, getMaximumFormInputSize
, setMaximumFormInputSize
, getMinimumUploadRate
, setMinimumUploadRate
, getMinimumUploadSeconds
, setMinimumUploadSeconds
, getUploadTimeout
, setUploadTimeout
, PartUploadPolicy
, disallow
, allowWithMaximumSize
, FileUploadException
, fileUploadExceptionReason
, BadPartException
, badPartExceptionReason
, PolicyViolationException
, policyViolationExceptionReason
) where
import Control.Arrow
import Control.Applicative
import Control.Exception (SomeException(..))
import Control.Monad
import Control.Monad.CatchIO
import Control.Monad.Trans
import qualified Data.Attoparsec.Char8 as Atto
import Data.Attoparsec.Char8 hiding (many, Result(..))
import Data.Attoparsec.Enumerator
import qualified Data.ByteString.Char8 as S
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Internal (c2w)
import qualified Data.CaseInsensitive as CI
import qualified Data.DList as D
import Data.Enumerator.Binary (iterHandle)
import Data.IORef
import Data.Int
import Data.List hiding (takeWhile)
import qualified Data.Map as Map
import Data.Maybe
import qualified Data.Text as T
import Data.Text (Text)
import qualified Data.Text.Encoding as TE
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable
import Prelude hiding (catch, getLine, takeWhile)
import System.Directory
import System.IO hiding (isEOF)
import Snap.Iteratee hiding (map)
import qualified Snap.Iteratee as I
import Snap.Internal.Debug
import Snap.Internal.Iteratee.Debug
import Snap.Internal.Iteratee.BoyerMooreHorspool
import Snap.Internal.Parsing
import Snap.Types
------------------------------------------------------------------------------
-- | Reads a @multipart/form-data@ request body, streams each file part into a
-- temporary file under the given directory, and hands the resulting list to
-- a user handler.
--
-- * The per-part callback decides each part's 'PartUploadPolicy':
--   'disallow' yields a 'Left' 'PolicyViolationException' for that part, and
--   'allowWithMaximumSize' streams the part into a fresh temporary file
--   (name prefix @snap-@) whose path is reported as 'Right'.
--
-- * A part that grows beyond its allowed size is abandoned. Its temporary
--   file is closed, the rest of the part is skipped, and a 'Left'
--   'PolicyViolationException' is reported. The partial file is removed
--   with the others at the end.
--
-- * When the 'UploadPolicy' enables form-input processing, parts without a
--   @filename@ become request parameters and are not passed to the handler
--   (see 'handleMultipart').
--
-- * Every temporary file is removed once the handler returns or throws, so
--   the handler must move or copy any file it wants to keep.
handleFileUploads ::
       (MonadSnap m) =>
       FilePath                        -- ^ temporary directory
    -> UploadPolicy                    -- ^ general upload policy
    -> (PartInfo -> PartUploadPolicy)  -- ^ per-part upload policy
    -> ([(PartInfo, Either PolicyViolationException FilePath)] -> m a)
                                       -- ^ user handler
    -> m a
handleFileUploads tmpdir uploadPolicy partPolicy handler = do
    uploadedFiles <- newUploadedFiles

    (do xs <- handleMultipart uploadPolicy (iter uploadedFiles)
        handler xs
      ) `finally` (cleanupUploadedFiles uploadedFiles)

  where
    iter uploadedFiles partInfo = maybe disallowed takeIt mbFs
      where
        ctText = partContentType partInfo
        fnText = fromMaybe "" $ partFileName partInfo

        -- These bytes come straight from the client. Decode leniently so
        -- that an invalid UTF-8 sequence cannot raise an exception when the
        -- violation message below is eventually forced.
        ct = TE.decodeUtf8With lenientDecode ctText
        fn = TE.decodeUtf8With lenientDecode fnText

        (PartUploadPolicy mbFs) = partPolicy partInfo

        retVal (_, x) = (partInfo, Right x)

        takeIt maxSize = do
            let it = fmap retVal $
                     joinI' $
                     takeNoMoreThan maxSize $$
                     fileReader uploadedFiles tmpdir partInfo

            it `catches` [ Handler $ \(_ :: TooManyBytesReadException) -> do
                               -- fileReader never reached closeActiveFile, so
                               -- its handle is still registered as the
                               -- current file. Release it now; otherwise the
                               -- next file part's openFileForUpload fails
                               -- with the "pre-existing open handle" error.
                               -- The partial file is deleted by the final
                               -- cleanup.
                               closeActiveFile uploadedFiles
                               skipToEof
                               tooMany maxSize
                         , Handler $ \(e :: SomeException) -> throw e
                         ]

        tooMany maxSize =
            return ( partInfo
                   , Left $ PolicyViolationException
                          $ T.concat [ "File \""
                                     , fn
                                     , "\" exceeded maximum allowable size "
                                     , T.pack $ show maxSize ] )

        disallowed =
            return ( partInfo
                   , Left $ PolicyViolationException
                          $ T.concat [ "Policy disallowed upload of file \""
                                     , fn
                                     , "\" with content-type \""
                                     , ct
                                     , "\"" ] )
------------------------------------------------------------------------------
-- | Streams a @multipart/form-data@ request body through a user-supplied
-- iteratee, invoked once per part, and returns the per-part results in
-- order.
--
-- * If the request's content type is not @multipart/form-data@, this calls
--   'pass'.
--
-- * A @multipart/form-data@ request without a @boundary@ parameter raises
--   'BadPartException'.
--
-- * When 'doProcessFormInputs' is set, parts without a @filename@ are not
--   given to the part handler. Their contents (at most
--   'getMaximumFormInputSize' bytes) are stored as request parameters
--   instead. A 'PolicyViolationException' is raised when the request
--   already holds at least 'maximumNumberOfFormInputs' distinct parameter
--   names; this count includes parameters present before the body was read.
--
-- * The body is read through 'killIfTooSlow'. It is configured with the
--   policy's minimum upload rate and seconds, plus a timeout-bump action
--   built from 'uploadTimeout'.
handleMultipart ::
       (MonadSnap m) =>
       UploadPolicy                               -- ^ global upload policy
    -> (PartInfo -> Iteratee ByteString IO a)     -- ^ part handler
    -> m [a]
handleMultipart uploadPolicy origPartHandler = do
    hdrs <- liftM headers getRequest
    let (ct, mbBoundary) = getContentType hdrs

    tickleTimeout <- getTimeoutAction
    let bumpTimeout = tickleTimeout $ uploadTimeout uploadPolicy

    let partHandler = if doProcessFormInputs uploadPolicy
                        then captureVariableOrReadFile
                                 (getMaximumFormInputSize uploadPolicy)
                                 origPartHandler
                        else fmap File . origPartHandler

    when (ct /= "multipart/form-data") $ do
        debug $ "handleMultipart called with content-type=" ++ S.unpack ct
                  ++ ", passing"
        pass

    -- Pattern-match the boundary rather than isNothing + fromJust.
    boundary <- maybe (throw $ BadPartException $
                           "got multipart/form-data without boundary")
                      return
                      mbBoundary

    captures <- runRequestBody (iter bumpTimeout boundary partHandler)
    procCaptures [] captures

  where
    iter bump boundary ph = iterateeDebugWrapper "killIfTooSlow" $
                            killIfTooSlow
                                bump
                                (minimumUploadRate uploadPolicy)
                                (minimumUploadSeconds uploadPolicy)
                                (internalHandleMultipart boundary ph)

    -- insertWith' passes the new value ([v]) as the first argument, so (++)
    -- puts v in front of any values already stored under k.
    ins k v = Map.insertWith' (++) k [v]

    maxFormVars = maximumNumberOfFormInputs uploadPolicy

    -- File results are collected in order. Captured form variables are
    -- written into the request's parameter map as they are encountered.
    procCaptures l []                 = return $ reverse l
    procCaptures l (File x : xs)      = procCaptures (x : l) xs
    procCaptures l (Capture k v : xs) = do
        rq <- getRequest
        let n = Map.size $ rqParams rq
        when (n >= maxFormVars) $
            throw $ PolicyViolationException $
            T.concat [ "number of form inputs exceeded maximum of "
                     , T.pack $ show maxFormVars ]
        modifyRequest $ rqModifyParams (ins k v)
        procCaptures l xs
------------------------------------------------------------------------------
-- | Metadata for one part of a multipart request body, extracted from that
-- part's headers.
data PartInfo = PartInfo
    { partFieldName   :: !ByteString
      -- ^ @name@ parameter of @Content-Disposition@ (empty when absent)
    , partFileName    :: !(Maybe ByteString)
      -- ^ @filename@ parameter of @Content-Disposition@, if present
    , partContentType :: !ByteString
      -- ^ the part's media type, without parameters
    } deriving (Show)
------------------------------------------------------------------------------
-- | Root of the file-upload exception hierarchy. 'BadPartException' and
-- 'PolicyViolationException' are converted to 'SomeException' by way of
-- 'WrappedFileUploadException', so a handler for 'FileUploadException' also
-- receives them.
data FileUploadException
    = GenericFileUploadException
        { _genericFileUploadExceptionReason :: Text }
    | forall e . (Exception e, Show e) => WrappedFileUploadException
        { _wrappedFileUploadException       :: e
        , _wrappedFileUploadExceptionReason :: Text
        }
  deriving (Typeable)

-- | A generic exception shows its reason with a prefix. A wrapped exception
-- shows as the exception it wraps.
instance Show FileUploadException where
    show ex = case ex of
        GenericFileUploadException reason ->
            "File upload exception: " ++ T.unpack reason
        WrappedFileUploadException inner _ ->
            show inner

instance Exception FileUploadException

------------------------------------------------------------------------------
-- | The human-readable reason carried by either constructor.
fileUploadExceptionReason :: FileUploadException -> Text
fileUploadExceptionReason ex = case ex of
    GenericFileUploadException reason   -> reason
    WrappedFileUploadException _ reason -> reason
------------------------------------------------------------------------------
-- | Packs an exception and its reason into a 'WrappedFileUploadException'
-- and then into 'SomeException'. Exceptions below 'FileUploadException'
-- in the hierarchy use this as their 'toException'.
uploadExceptionToException :: Exception e => e -> Text -> SomeException
uploadExceptionToException ex = SomeException . WrappedFileUploadException ex

------------------------------------------------------------------------------
-- | Inverse of 'uploadExceptionToException'. It succeeds only when the
-- 'SomeException' holds a 'WrappedFileUploadException' whose payload has
-- the requested type.
uploadExceptionFromException :: Exception e => SomeException -> Maybe e
uploadExceptionFromException someEx =
    case fromException someEx of
        Just (WrappedFileUploadException inner _) -> cast inner
        _                                         -> Nothing
------------------------------------------------------------------------------
-- | Raised when the multipart body is malformed. Examples are a missing
-- boundary or part headers that exceed the size limit.
data BadPartException = BadPartException
    { badPartExceptionReason :: Text }
  deriving (Typeable)

-- | Sits below 'FileUploadException' in the exception hierarchy.
instance Exception BadPartException where
    toException ex@(BadPartException why) = uploadExceptionToException ex why
    fromException                         = uploadExceptionFromException

instance Show BadPartException where
    show (BadPartException why) = "Bad part: " ++ T.unpack why
------------------------------------------------------------------------------
-- | Raised, or reported per part, when an upload breaks the configured
-- 'UploadPolicy' or 'PartUploadPolicy'.
data PolicyViolationException = PolicyViolationException
    { policyViolationExceptionReason :: Text }
  deriving (Typeable)

-- | Sits below 'FileUploadException' in the exception hierarchy.
instance Exception PolicyViolationException where
    toException ex@(PolicyViolationException why) =
        uploadExceptionToException ex why
    fromException = uploadExceptionFromException

instance Show PolicyViolationException where
    show (PolicyViolationException why) =
        "File upload policy violation: " ++ T.unpack why
------------------------------------------------------------------------------
-- | Request-wide settings for multipart processing.
data UploadPolicy = UploadPolicy
    { processFormInputs         :: Bool
      -- ^ capture parts without a filename as request parameters
    , maximumFormInputSize      :: Int
      -- ^ byte limit for each captured form input
    , maximumNumberOfFormInputs :: Int
      -- ^ limit on distinct request-parameter names
    , minimumUploadRate         :: Double
      -- ^ passed to 'killIfTooSlow' (presumably bytes per second; verify)
    , minimumUploadSeconds      :: Int
      -- ^ passed to 'killIfTooSlow' (presumably a grace period; verify)
    , uploadTimeout             :: Int
      -- ^ amount handed to the request's timeout-tickling action
    } deriving (Show, Eq)
------------------------------------------------------------------------------
-- | The stock policy. It processes form inputs, allows 128 KiB (2^17 bytes)
-- per form input and at most 10 form inputs, and uses a minimum rate of
-- 1000, a 10-second minimum and an upload timeout of 20.
defaultUploadPolicy :: UploadPolicy
defaultUploadPolicy = UploadPolicy
    { processFormInputs         = True
    , maximumFormInputSize      = 2 ^ (17 :: Int)
    , maximumNumberOfFormInputs = 10
    , minimumUploadRate         = 1000
    , minimumUploadSeconds      = 10
    , uploadTimeout             = 20
    }
------------------------------------------------------------------------------
-- | Whether parts without a filename are captured as request parameters.
doProcessFormInputs :: UploadPolicy -> Bool
doProcessFormInputs policy = processFormInputs policy

-- | Enable or disable capturing of form inputs.
setProcessFormInputs :: Bool -> UploadPolicy -> UploadPolicy
setProcessFormInputs flag policy = policy { processFormInputs = flag }

-- | Byte limit for each captured form input.
getMaximumFormInputSize :: UploadPolicy -> Int
getMaximumFormInputSize policy = maximumFormInputSize policy

-- | Replace the per-form-input byte limit.
setMaximumFormInputSize :: Int -> UploadPolicy -> UploadPolicy
setMaximumFormInputSize limit policy = policy { maximumFormInputSize = limit }

-- | Minimum upload rate handed to 'killIfTooSlow'.
getMinimumUploadRate :: UploadPolicy -> Double
getMinimumUploadRate policy = minimumUploadRate policy

-- | Replace the minimum upload rate.
setMinimumUploadRate :: Double -> UploadPolicy -> UploadPolicy
setMinimumUploadRate rate policy = policy { minimumUploadRate = rate }

-- | Minimum-seconds setting handed to 'killIfTooSlow'.
getMinimumUploadSeconds :: UploadPolicy -> Int
getMinimumUploadSeconds policy = minimumUploadSeconds policy

-- | Replace the minimum-seconds setting.
setMinimumUploadSeconds :: Int -> UploadPolicy -> UploadPolicy
setMinimumUploadSeconds secs policy = policy { minimumUploadSeconds = secs }

-- | Amount handed to the request's timeout-tickling action.
getUploadTimeout :: UploadPolicy -> Int
getUploadTimeout policy = uploadTimeout policy

-- | Replace the upload timeout.
setUploadTimeout :: Int -> UploadPolicy -> UploadPolicy
setUploadTimeout secs policy = policy { uploadTimeout = secs }
------------------------------------------------------------------------------
-- | Per-part decision. 'Nothing' rejects the part; 'Just' n accepts a file
-- of at most n bytes.
data PartUploadPolicy = PartUploadPolicy
    { _maximumFileSize :: Maybe Int64 }
  deriving (Show, Eq)

-- | Reject the part.
disallow :: PartUploadPolicy
disallow = PartUploadPolicy { _maximumFileSize = Nothing }

-- | Accept the part, up to the given number of bytes.
allowWithMaximumSize :: Int64 -> PartUploadPolicy
allowWithMaximumSize limit = PartUploadPolicy (Just limit)
------------------------------------------------------------------------------
-- | Wraps a file-part handler so that parts without a @filename@ are read
-- into memory as a form variable ('Capture') instead of being handed to the
-- file handler.
--
-- A captured variable may be at most @maxSize@ bytes. Beyond that, a
-- 'PolicyViolationException' naming the field is thrown into the iteratee.
-- Every other error is re-thrown unchanged.
captureVariableOrReadFile ::
       Int                                               -- ^ max input size
    -> (PartInfo -> Iteratee ByteString IO a)            -- ^ file handler
    -> (PartInfo -> Iteratee ByteString IO (Capture a))
captureVariableOrReadFile maxSize fileHandler partInfo =
    case partFileName partInfo of
      Nothing -> iter
      _       -> liftM File $ fileHandler partInfo
  where
    iter = varIter `catchError` handler

    fieldName = partFieldName partInfo

    varIter = do
        var <- liftM S.concat $
               joinI' $
               takeNoMoreThan (fromIntegral maxSize) $$ consume
        return $ Capture fieldName var

    handler e = do
        let m = fromException e :: Maybe TooManyBytesReadException
        case m of
          Nothing -> throwError e
          Just _  -> throwError $ PolicyViolationException $
                     T.concat [ "form input '"
                              -- Field names are client-supplied; decode
                              -- leniently so a non-UTF-8 name cannot make
                              -- the error message itself blow up.
                              , TE.decodeUtf8With lenientDecode fieldName
                              , "' exceeded maximum permissible size ("
                              , T.pack $ show maxSize
                              , " bytes)" ]
------------------------------------------------------------------------------
-- | Per-part result when form-input processing is on. It is either a
-- captured form variable (field name, value) or whatever the file handler
-- returned.
data Capture a
    = Capture ByteString ByteString   -- ^ field name, captured value
    | File a                          -- ^ result of the file-part handler
  deriving (Show)
------------------------------------------------------------------------------
-- | Opens a fresh temporary file in @tmpdir@ and copies the part's bytes
-- into it. At end of input the file is closed and recorded as finished.
-- Returns the part info together with the temporary file's path.
--
-- Exceptions caught by the 'catch' wrapper are re-raised with 'throwError'.
fileReader :: UploadedFiles
           -> FilePath
           -> PartInfo
           -> Iteratee ByteString IO (PartInfo, FilePath)
fileReader uploadedFiles tmpdir partInfo = do
    (path, hdl) <- openFileForUpload uploadedFiles tmpdir
    iterateeDebugWrapper "fileReader" (writeThenClose path hdl)
        `catch` \(ex :: SomeException) -> throwError ex
  where
    writeThenClose path hdl = do
        iterHandle hdl
        debug "fileReader: closing active file"
        closeActiveFile uploadedFiles
        return (partInfo, path)
------------------------------------------------------------------------------
-- | Core multipart driver.
--
-- It first skips everything up to the first @--boundary@ line. It then
-- splits the rest of the stream on @\\r\\n--boundary@ with the
-- Boyer-Moore-Horspool enumeratee and processes each part:
--
-- * The part headers are parsed, with at most 'mAX_HDRS_SIZE' bytes
--   allowed; exceeding that raises 'BadPartException'.
--
-- * A @multipart/mixed@ part is processed recursively with its own
--   boundary. Every nested part reuses the outer part's field name.
--
-- * Any other part becomes a 'PartInfo' that is handed to @clientHandler@.
--
-- On any exception the remaining input is skipped and the exception is
-- re-thrown.
internalHandleMultipart ::
ByteString
-> (PartInfo -> Iteratee ByteString IO a)
-> Iteratee ByteString IO [a]
internalHandleMultipart boundary clientHandler = go `catch` errorHandler
where
-- Drain the request body before propagating, so the remaining input is
-- consumed rather than left in the stream.
errorHandler :: SomeException -> Iteratee ByteString IO a
errorHandler e = do
skipToEof
throwError e
go = do
_ <- iterParser $ parseFirstBoundary boundary
step <- iterateeDebugWrapper "boyer-moore" $
(bmhEnumeratee (fullBoundary boundary) $$ processParts iter)
-- iter returns a list per part (several for multipart/mixed); flatten.
liftM concat $ lift $ run_ $ returnI step
-- Matches "--" followed by the boundary; Atto.try lets the caller fall back.
pBoundary b = Atto.try $ do
_ <- string "--"
string b
-- Delimiter between parts: CRLF, "--", then the boundary.
fullBoundary b = S.concat ["\r\n", "--", b]
pLine = takeWhile (not . isEndOfLine . c2w) <* eol
takeLine = pLine *> pure ()
-- Discard preamble lines until a line starting with "--boundary" is found.
parseFirstBoundary b = pBoundary b <|> (takeLine *> parseFirstBoundary b)
-- Parse the part headers, capped at mAX_HDRS_SIZE bytes. Hitting the cap
-- is reported as BadPartException; other errors propagate unchanged.
takeHeaders = hdrs `catchError` handler
where
hdrs = liftM toHeaders $
iterateeDebugWrapper "header parser" $
joinI' $
takeNoMoreThan mAX_HDRS_SIZE $$
iterParser pHeadersWithSeparator
handler e = do
let m = fromException e :: Maybe TooManyBytesReadException
case m of
Nothing -> throwError e
Just _ -> throwError $ BadPartException $
"headers exceeded maximum size"
-- Per-part iteratee for the top-level body. A nested multipart/mixed
-- part must carry its own boundary parameter.
iter = do
hdrs <- takeHeaders
let (contentType, mboundary) = getContentType hdrs
let (fieldName, fileName) = getFieldName hdrs
if contentType == "multipart/mixed"
then maybe (throwError $ BadPartException $
"got multipart/mixed without boundary")
(processMixed fieldName)
mboundary
else do
let info = PartInfo fieldName fileName contentType
liftM (:[]) $ clientHandler info
-- Same splitting as 'go', applied to the body of a multipart/mixed part.
processMixed fieldName mixedBoundary = do
_ <- iterParser $ parseFirstBoundary mixedBoundary
step <- iterateeDebugWrapper "boyer-moore" $
(bmhEnumeratee (fullBoundary mixedBoundary) $$
processParts (mixedIter fieldName))
lift $ run_ $ returnI step
-- Nested parts take their field name from the enclosing part. Only the
-- filename and content type come from the nested headers.
mixedIter fieldName = do
hdrs <- takeHeaders
let (contentType, _) = getContentType hdrs
let (_, fileName) = getFieldName hdrs
let info = PartInfo fieldName fileName contentType
clientHandler info
------------------------------------------------------------------------------
-- | Splits the @Content-Type@ header into the bare media type and its
-- optional @boundary@ parameter. A missing header defaults to
-- @text/plain@. An unparseable header is treated as @text/plain@ with no
-- parameters.
getContentType :: Headers
               -> (ByteString, Maybe ByteString)
getContentType hdrs = (mediaType, findParam "boundary" parameters)
  where
    rawValue = fromMaybe "text/plain" (getHeader "content-type" hdrs)

    (mediaType, parameters) =
        either (const ("text/plain", [])) id
               (fullyParse rawValue pContentTypeWithParameters)
------------------------------------------------------------------------------
-- | Reads the @name@ and @filename@ parameters of the
-- @Content-Disposition@ header. The name defaults to the empty string; a
-- missing or unparseable header yields no parameters.
getFieldName :: Headers -> (ByteString, Maybe ByteString)
getFieldName hdrs =
    ( fromMaybe "" (findParam "name" dispositionParams)
    , findParam "filename" dispositionParams
    )
  where
    rawDisposition = fromMaybe "" (getHeader "content-disposition" hdrs)

    dispositionParams =
        either (const []) snd
               (fullyParse rawDisposition pValueWithParameters)
------------------------------------------------------------------------------
-- | Value of the first association whose key equals the given one.
findParam :: (Eq a) => a -> [(a, b)] -> Maybe b
findParam = lookup
------------------------------------------------------------------------------
-- | Converts the Boyer-Moore-Horspool 'MatchInfo' stream back into plain
-- bytes for one part:
--
-- * Each 'NoMatch' chunk is fed to the inner iteratee.
--
-- * The inner iteratee receives 'EOF' once a 'Match' element is consumed or
--   the outer stream ends.
--
-- * Elements after that 'Match' are left in the outer stream for the next
--   part.
--
-- The inner iteratee's resulting 'Step' is returned.
processPart :: (Monad m) => Enumeratee MatchInfo ByteString m a
processPart st =
case st of
(Continue k) -> go k
-- Inner iteratee already finished (or errored): consume nothing.
_ -> yield st (Chunks [])
where
go :: (Monad m) => (Stream ByteString -> Iteratee ByteString m a)
-> Iteratee MatchInfo m (Step ByteString m a)
go !k =
I.head >>= maybe finish process
where
-- Outer stream exhausted: end the part.
finish =
lift $ runIteratee $ k EOF
process (NoMatch !s) = do
!step <- lift $ runIteratee $ k $ Chunks [s]
case step of
(Continue k') -> go k'
-- Inner iteratee stopped early; stop pulling MatchInfo elements.
_ -> yield step (Chunks [])
-- Boundary reached: the part is complete.
process (Match _) =
lift $ runIteratee $ k EOF
------------------------------------------------------------------------------
-- | Runs @partIter@ once per boundary-delimited segment and collects the
-- results in order.
--
-- Each segment begins with the bytes that follow a boundary:
--
-- * An end of line means another part follows.
--
-- * @--@ marks the closing boundary, and processing stops there.
--
-- After @partIter@ returns, the rest of the segment is skipped.
processParts :: Iteratee ByteString IO a
-> Iteratee MatchInfo IO [a]
processParts partIter = iterateeDebugWrapper "processParts" $ go D.empty
where
-- Iteratee for one segment: Nothing for the closing boundary, otherwise the
-- part's result.
iter = do
isLast <- bParser
if isLast
then return Nothing
else do
x <- partIter
-- discard whatever partIter left unread in this part
skipToEof
return $ Just x
-- Accumulate results in a difference list until input ends or the closing
-- boundary is seen.
go soFar = do
b <- isEOF
if b
then return $ D.toList soFar
else do
innerStep <- processPart $$ iter
output <- lift $ run_ $ returnI innerStep
case output of
Just x -> go (D.append soFar $ D.singleton x)
Nothing -> return $ D.toList soFar
bParser = iterateeDebugWrapper "boundary debugger" $
iterParser $ pBoundaryEnd
-- True for the closing "--", False for a line ending.
pBoundaryEnd = (eol *> pure False) <|> (string "--" *> pure True)
------------------------------------------------------------------------------
-- | A line terminator: a bare LF is tried first, then CRLF.
eol :: Parser ByteString
eol = string "\n" <|> string "\r\n"

------------------------------------------------------------------------------
-- | Header lines followed by the blank CRLF line that ends the header block.
pHeadersWithSeparator :: Parser [(ByteString,ByteString)]
pHeadersWithSeparator = do
    fields <- pHeaders
    _      <- crlf
    return fields
------------------------------------------------------------------------------
-- | Builds a case-insensitive header map. Repeated header names keep all of
-- their values, in the order in which they appeared.
toHeaders :: [(ByteString,ByteString)] -> Headers
toHeaders = Map.fromListWith (flip (++)) . map (CI.mk *** (: []))

------------------------------------------------------------------------------
-- | Byte limit for the header block of a single part.
mAX_HDRS_SIZE :: Int64
mAX_HDRS_SIZE = 32768
------------------------------------------------------------------------------
-- | Bookkeeping for the temporary files created while handling one request.
data UploadedFilesState = UploadedFilesState
    { _currentFile      :: Maybe (FilePath, Handle)
      -- ^ file currently being written, if any
    , _alreadyReadFiles :: [FilePath]
      -- ^ closed files that still have to be deleted at cleanup
    }

-- | No open file and nothing to clean up.
emptyUploadedFilesState :: UploadedFilesState
emptyUploadedFilesState = UploadedFilesState
    { _currentFile      = Nothing
    , _alreadyReadFiles = []
    }

-- | Mutable handle on an 'UploadedFilesState'.
newtype UploadedFiles = UploadedFiles (IORef UploadedFilesState)

-- | Allocate a tracker in the empty state.
newUploadedFiles :: MonadIO m => m UploadedFiles
newUploadedFiles = liftIO (fmap UploadedFiles (newIORef emptyUploadedFilesState))
------------------------------------------------------------------------------
-- | Closes and deletes the file being written, if any, and deletes every
-- finished file. The tracker is then reset to the empty state. Errors from
-- closing or removing files are ignored.
cleanupUploadedFiles :: (MonadIO m) => UploadedFiles -> m ()
cleanupUploadedFiles (UploadedFiles ref) = liftIO $ do
    st <- readIORef ref
    case _currentFile st of
        Nothing         -> return ()
        Just (path, h)  -> do
            eatException $ hClose h
            eatException $ removeFile path
    mapM_ (eatException . removeFile) (_alreadyReadFiles st)
    writeIORef ref emptyUploadedFilesState
------------------------------------------------------------------------------
-- | Creates an unbuffered binary temporary file (prefix @snap-@) in
-- @tmpdir@, records it as the current file and returns its path and handle.
--
-- Only one file may be open at a time. If one already is, every tracked
-- file is cleaned up and a 'GenericFileUploadException' is thrown.
openFileForUpload :: (MonadIO m) =>
                     UploadedFiles
                  -> FilePath
                  -> m (FilePath, Handle)
openFileForUpload ufs@(UploadedFiles ref) tmpdir = liftIO $ do
    st <- readIORef ref

    unless (isNothing $ _currentFile st) $ do
        cleanupUploadedFiles ufs
        throw $ GenericFileUploadException alreadyOpenMsg

    opened@(_, h) <- openBinaryTempFile tmpdir "snap-"
    hSetBuffering h NoBuffering
    writeIORef ref $ st { _currentFile = Just opened }
    return opened

  where
    alreadyOpenMsg =
        T.concat [ "Internal error! UploadedFiles: "
                 , "opened new file with pre-existing open handle" ]
------------------------------------------------------------------------------
-- | Closes the current file, if any, ignoring close errors. The file then
-- moves to the list of finished files so that 'cleanupUploadedFiles' will
-- delete it. With no current file this does nothing.
closeActiveFile :: (MonadIO m) => UploadedFiles -> m ()
closeActiveFile (UploadedFiles ref) = liftIO $ do
    st <- readIORef ref
    case _currentFile st of
        Nothing        -> return ()
        Just (path, h) -> do
            eatException $ hClose h
            writeIORef ref $
                st { _currentFile      = Nothing
                   , _alreadyReadFiles = path : _alreadyReadFiles st }
------------------------------------------------------------------------------
-- | Runs an action for its effect only. Any exception it throws is caught
-- and discarded, so callers should use this only for best-effort cleanup.
eatException :: (MonadCatchIO m) => m a -> m ()
eatException action = discardResult action `catch` swallow
  where
    discardResult = liftM (const ())

    swallow :: Monad n => SomeException -> n ()
    swallow _ = return ()