module Network.HTTP.Download.Verified
( verifiedDownload
, DownloadRequest(..)
, HashCheck(..)
, CheckHexDigest(..)
, LengthCheck
, VerifiedDownloadException(..)
) where
import qualified Data.List as List
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Base64 as B64
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.List as CL
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Reader
import Control.Applicative
import Crypto.Hash
import Crypto.Hash.Conduit (sinkHash)
import Data.ByteString (ByteString)
import Data.Conduit
import Data.Conduit.Binary (sourceHandle, sinkHandle)
import Data.Foldable (traverse_)
import Data.Monoid
import Data.String
import Data.Typeable (Typeable)
import Network.HTTP.Client.Conduit
import Network.HTTP.Types.Header (hContentLength, hContentMD5)
import Path
import Prelude
import System.FilePath((<.>))
import System.Directory
import System.IO
-- | An HTTP request together with the checks its downloaded content must pass.
data DownloadRequest = DownloadRequest
    { drRequest :: Request
      -- ^ The HTTP request to perform.
    , drHashChecks :: [HashCheck]
      -- ^ Digests the content must match; every check in the list is applied.
    , drLengthCheck :: Maybe LengthCheck
      -- ^ Expected content length in bytes, if known.
    }
    deriving Show
-- | A single digest check: a hash algorithm plus the digest the content is
-- expected to have.
--
-- The algorithm is existentially quantified so that checks using different
-- algorithms can live in one list. Within this module the algorithm value is
-- only used to pick the hash type and, via 'Show', to name it in errors.
data HashCheck = forall a. (Show a, HashAlgorithm a) => HashCheck
    { hashCheckAlgorithm :: a
      -- ^ Algorithm witness (e.g. @MD5@).
    , hashCheckHexDigest :: CheckHexDigest
      -- ^ Expected digest.
    }
deriving instance Show HashCheck
-- | An expected digest, tagged with its representation. The tag decides how
-- it is compared with the computed digest (see 'sinkCheckHash').
data CheckHexDigest
    = CheckHexDigestString String
      -- ^ Compared against the 'show'n digest.
    | CheckHexDigestByteString ByteString
      -- ^ Compared against the hex-encoded digest.
    | CheckHexDigestHeader ByteString
      -- ^ Raw value of an HTTP header such as @Content-MD5@; may be base64.
    deriving Show

-- | String literals (with OverloadedStrings) become 'CheckHexDigestString'.
instance IsString CheckHexDigest where
    fromString = CheckHexDigestString
-- | Expected length of the downloaded content, in bytes.
type LengthCheck = Int
-- | Thrown when downloaded content does not meet the expectations recorded
-- in its 'DownloadRequest'.
data VerifiedDownloadException
    = WrongContentLength
          Request    -- ^ the request
          Int        -- ^ expected length
          ByteString -- ^ actual raw @Content-Length@ header value
    | WrongStreamLength
          Request    -- ^ the request
          Int        -- ^ expected number of bytes
          Int        -- ^ number of bytes actually received
    | WrongDigest
          Request        -- ^ the request
          String         -- ^ 'show'n hash algorithm
          CheckHexDigest -- ^ expected digest
          String         -- ^ 'show'n computed digest
    deriving (Typeable)
-- | Human-readable, multi-line rendering: headline, expected value, actual
-- value, and the request URI.
instance Show VerifiedDownloadException where
    show exc = case exc of
        WrongContentLength req expected actual ->
            report req
                "Download expectation failure: ContentLength header\n"
                (show expected)
                (displayByteString actual)
        WrongStreamLength req expected actual ->
            report req
                "Download expectation failure: download size\n"
                (show expected)
                (show actual)
        WrongDigest req algo expected actual ->
            report req
                ("Download expectation failure: content hash (" ++ algo ++ ")\n")
                (displayCheckHexDigest expected)
                actual
      where
        -- All variants share the same layout after the headline.
        report req headline expectedStr actualStr = concat
            [ headline
            , "Expected: ", expectedStr, "\n"
            , "Actual: ", actualStr, "\n"
            , "For: ", show (getUri req)
            ]

instance Exception VerifiedDownloadException
-- | Thrown while checking a file that already exists on disk. It is not
-- exported: 'verifiedDownload' catches it and falls back to downloading
-- the file again.
data VerifyFileException
    = WrongFileSize
          Int     -- ^ expected size in bytes
          Integer -- ^ actual size in bytes, as reported by 'hFileSize'
    deriving (Show, Typeable)
instance Exception VerifyFileException
-- | Render a 'ByteString' for humans: decode it as UTF-8 and strip
-- surrounding whitespace.
--
-- Decoding is lenient: each invalid UTF-8 byte becomes U+FFFD instead of
-- raising an exception. The input may be a server-controlled header value or
-- raw binary, such as a base64-decoded digest. This function is also used
-- while rendering exceptions and comparing header values, where a strict
-- decoder would throw a UnicodeException and hide the real failure.
displayByteString :: ByteString -> String
displayByteString =
    Text.unpack . Text.strip . Text.decodeUtf8With replaceInvalid
  where
    -- Same behaviour as Data.Text.Encoding.Error.lenientDecode.
    replaceInvalid _ _ = Just '\xFFFD'
-- | Render an expected digest for error messages, tagged with its kind.
-- A header value is shown base64-decoded, followed by its raw form.
displayCheckHexDigest :: CheckHexDigest -> String
displayCheckHexDigest digest = case digest of
    CheckHexDigestString str ->
        str ++ " (String)"
    CheckHexDigestByteString bytes ->
        displayByteString bytes ++ " (ByteString)"
    CheckHexDigestHeader header -> concat
        [ displayByteString (B64.decodeLenient header)
        , " (Header. unencoded: "
        , displayByteString header
        , ")"
        ]
-- | Consume the whole stream and hash it with the check's algorithm. Throw
-- 'WrongDigest' if the result does not match the expected digest.
--
-- A 'CheckHexDigestHeader' value is accepted in any of these forms:
--
-- * base64 of the raw digest bytes, as RFC 1864 specifies for @Content-MD5@;
-- * base64 of the hex digest (a non-standard encoding some servers use);
-- * the bare hex digest, unencoded.
sinkCheckHash :: MonadThrow m
    => Request
    -> HashCheck
    -> Consumer ByteString m ()
sinkCheckHash req HashCheck{..} = do
    digest <- sinkHashUsing hashCheckAlgorithm
    let actualDigestString = show digest
        actualDigestHexByteString = digestToHexByteString digest
        passedCheck = case hashCheckHexDigest of
            CheckHexDigestString s -> s == actualDigestString
            CheckHexDigestByteString b -> b == actualDigestHexByteString
            CheckHexDigestHeader b ->
                let decoded = B64.decodeLenient b
                in  hexEncode decoded == actualDigestHexByteString
                        || decoded == actualDigestHexByteString
                        || b == actualDigestHexByteString
    unless passedCheck $
        throwM $ WrongDigest req (show hashCheckAlgorithm) hashCheckHexDigest actualDigestString
  where
    -- Lowercase hex encoding of raw bytes. This assumes digestToHexByteString
    -- also emits lowercase hex (TODO confirm against the crypto library).
    hexEncode :: ByteString -> ByteString
    hexEncode = ByteString.concatMap $ \w ->
        let (hi, lo) = w `divMod` 16
        in  ByteString.pack [hexDigit hi, hexDigit lo]
    -- 48 is '0'; 87 + 10 is 'a'.
    hexDigit d = if d < 10 then d + 48 else d + 87
-- | A 'ZipSink' that counts the bytes flowing through it. When the stream
-- ends, it throws 'WrongStreamLength' if the total differs from the
-- expected length.
assertLengthSink :: MonadThrow m
    => Request
    -> LengthCheck
    -> ZipSink ByteString m ()
assertLengthSink req expectedStreamLength = ZipSink $ do
    actualStreamLength <- CL.fold (\total chunk -> total + ByteString.length chunk) 0
    unless (actualStreamLength == expectedStreamLength) $
        throwM $ WrongStreamLength req expectedStreamLength actualStreamLength
-- | Hash the stream with the algorithm selected by the argument's type. The
-- argument's value is never inspected; it only fixes the type of 'Digest'.
sinkHashUsing :: (Monad m, HashAlgorithm a) => a -> Consumer ByteString m (Digest a)
sinkHashUsing = const sinkHash
-- | Run every hash check over the same stream in parallel (via 'ZipSink').
-- An empty list of checks yields a sink that succeeds trivially.
hashChecksToZipSink :: MonadThrow m => Request -> [HashCheck] -> ZipSink ByteString m ()
hashChecksToZipSink req checks = traverse_ toSink checks
  where
    toSink check = ZipSink (sinkCheckHash req check)
-- | Download the resource described by a 'DownloadRequest' to @destpath@,
-- verifying it along the way.
--
-- If @destpath@ already exists and passes every length and hash check in the
-- request, nothing is downloaded and 'False' is returned. Otherwise:
--
-- * the parent directory is created if needed;
-- * the response body is streamed into @destpath <.> "tmp"@ while the checks
--   run and @progressSink@ observes the bytes;
-- * the temporary file is renamed onto @destpath@ and 'True' is returned.
--
-- If the server sends a @Content-MD5@ header, it is checked as well.
-- Throws 'VerifiedDownloadException' when the downloaded content fails any
-- check.
verifiedDownload :: (MonadReader env m, HasHttpManager env, MonadIO m)
    => DownloadRequest
    -> Path Abs File                      -- ^ destination
    -> Sink ByteString (ReaderT env IO) () -- ^ sees every downloaded byte
    -> m Bool                             -- ^ whether a download happened
verifiedDownload DownloadRequest{..} destpath progressSink = do
    let req = drRequest
    env <- ask
    liftIO $ whenM' getShouldDownload $ do
        createDirectoryIfMissing True dir
        withBinaryFile fptmp WriteMode $ \h ->
            flip runReaderT env $
                withResponse req (go h)
        renameFile fptmp fp
  where
    -- Run the action only when the condition holds; report whether it ran.
    whenM' mp m = do
        p <- mp
        if p then m >> return True else return False

    fp = toFilePath destpath
    fptmp = fp <.> "tmp"
    dir = toFilePath $ parent destpath

    getShouldDownload = do
        fileExists <- doesFileExist fp
        if fileExists
            then not <$> fileMatchesExpectations
            else return True

    -- Both failure types mean "the existing file is not acceptable, fetch it
    -- again". 'catches' installs both handlers around the check at the same
    -- level. A chain of infix `catch` with lambdas would nest the second
    -- handler inside the first lambda's body, letting a WrongDigest escape.
    fileMatchesExpectations =
        (checkExpectations >> return True) `catches`
            [ Handler $ \(_ :: VerifyFileException) -> return False
            , Handler $ \(_ :: VerifiedDownloadException) -> return False
            ]

    checkExpectations = withBinaryFile fp ReadMode $ \h -> do
        traverse_ (checkFileSizeExpectations h) drLengthCheck
        sourceHandle h $$ getZipSink (hashChecksToZipSink drRequest drHashChecks)

    -- Compare as Integer so a file larger than maxBound :: Int is reported
    -- as a size mismatch rather than overflowing.
    checkFileSizeExpectations h expectedFileSize = do
        fileSizeInteger <- hFileSize h
        when (fileSizeInteger /= toInteger expectedFileSize) $
            throwM $ WrongFileSize expectedFileSize fileSizeInteger

    checkContentLengthHeader headers expectedContentLength =
        case List.lookup hContentLength headers of
            Just lengthBS -> do
                let lengthStr = displayByteString lengthBS
                when (lengthStr /= show expectedContentLength) $
                    throwM $ WrongContentLength drRequest expectedContentLength lengthBS
            Nothing -> return ()

    go h res = do
        let headers = responseHeaders res
        traverse_ (checkContentLengthHeader headers) drLengthCheck
        let headerHashChecks = case List.lookup hContentMD5 headers of
                Just md5BS ->
                    [ HashCheck
                        { hashCheckAlgorithm = MD5
                        , hashCheckHexDigest = CheckHexDigestHeader md5BS
                        }
                    ]
                Nothing -> []
            hashChecks = headerHashChecks ++ drHashChecks
        -- With a length check, read at most the expected number of bytes;
        -- assertLengthSink then rejects a short stream.
        responseBody res
            $= maybe (awaitForever yield) CB.isolate drLengthCheck
            $$ getZipSink
                ( hashChecksToZipSink drRequest hashChecks
                *> maybe (pure ()) (assertLengthSink drRequest) drLengthCheck
                *> ZipSink (sinkHandle h)
                *> ZipSink progressSink)