{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase        #-}
-- |
-- Wrapper functions of 'Network.HTTP.Simple' and 'Network.HTTP.Client' to
-- add the 'User-Agent' HTTP request header to each request.

module Network.HTTP.StackClient
  ( httpJSON
  , httpLbs
  , httpNoBody
  , httpSink
  , withResponse
  , setRequestMethod
  , setRequestHeader
  , addRequestHeader
  , setRequestBody
  , getResponseHeaders
  , getResponseBody
  , getResponseStatusCode
  , parseRequest
  , getUri
  , path
  , checkResponse
  , parseUrlThrow
  , requestHeaders
  , getGlobalManager
  , applyDigestAuth
  , displayDigestAuthException
  , Request
  , RequestBody(RequestBodyBS, RequestBodyLBS)
  , Response
  , HttpException
  , hAccept
  , hContentLength
  , hContentMD5
  , methodPut
  , formDataBody
  , partFileRequestBody
  , partBS
  , partLBS
  , setGithubHeaders
  , download
  , redownload
  , verifiedDownload
  , verifiedDownloadWithProgress
  , CheckHexDigest (..)
  , DownloadRequest
  , drRetryPolicyDefault
  , VerifiedDownloadException (..)
  , HashCheck (..)
  , mkDownloadRequest
  , setHashChecks
  , setLengthCheck
  , setRetryPolicy
  , setForceDownload
  ) where

import           Control.Monad.State (get, put, modify)
import           Data.Aeson (FromJSON)
import qualified Data.ByteString as Strict
import           Data.Conduit (ConduitM, ConduitT, awaitForever, (.|), yield, await)
import           Data.Conduit.Lift (evalStateC)
import qualified Data.Conduit.List as CL
import           Data.Monoid (Sum (..))
import qualified Data.Text as T
import           Data.Time.Clock (NominalDiffTime, diffUTCTime, getCurrentTime)
import           Network.HTTP.Client (Request, RequestBody(..), Response, parseRequest, getUri, path, checkResponse, parseUrlThrow)
import           Network.HTTP.Simple (setRequestMethod, setRequestBody, setRequestHeader, addRequestHeader, HttpException(..), getResponseBody, getResponseStatusCode, getResponseHeaders)
import           Network.HTTP.Types (hAccept, hContentLength, hContentMD5, methodPut)
import           Network.HTTP.Conduit (requestHeaders)
import           Network.HTTP.Client.TLS (getGlobalManager, applyDigestAuth, displayDigestAuthException)
import           Network.HTTP.Download hiding (download, redownload, verifiedDownload)
import qualified Network.HTTP.Download as Download
import qualified Network.HTTP.Simple
import           Network.HTTP.Client.MultipartFormData (formDataBody, partFileRequestBody, partBS, partLBS)
import           Path
import           Prelude (until, (!!))
import           RIO
import           RIO.PrettyPrint
import           Text.Printf (printf)


-- | Set the @User-Agent@ request header to @"The Haskell Stack"@.
--
-- Every request-issuing wrapper in this module applies this before handing
-- the request to the underlying HTTP function.
setUserAgent :: Request -> Request
setUserAgent req = setRequestHeader "User-Agent" ["The Haskell Stack"] req


-- | 'Network.HTTP.Simple.httpJSON', applied to the request after
-- 'setUserAgent' has set its @User-Agent@ header.
httpJSON :: (MonadIO m, FromJSON a) => Request -> m (Response a)
httpJSON req = Network.HTTP.Simple.httpJSON (setUserAgent req)


-- | 'Network.HTTP.Simple.httpLbs', applied to the request after
-- 'setUserAgent' has set its @User-Agent@ header.
httpLbs :: MonadIO m => Request -> m (Response LByteString)
httpLbs req = Network.HTTP.Simple.httpLbs (setUserAgent req)


-- | 'Network.HTTP.Simple.httpNoBody', applied to the request after
-- 'setUserAgent' has set its @User-Agent@ header.
httpNoBody :: MonadIO m => Request -> m (Response ())
httpNoBody req = Network.HTTP.Simple.httpNoBody (setUserAgent req)


-- | 'Network.HTTP.Simple.httpSink', applied to the request after
-- 'setUserAgent' has set its @User-Agent@ header. The sink argument is
-- passed through unchanged.
httpSink
  :: MonadUnliftIO m
  => Request
  -> (Response () -> ConduitM Strict.ByteString Void m a)
  -> m a
httpSink req sink = Network.HTTP.Simple.httpSink (setUserAgent req) sink


-- | 'Network.HTTP.Simple.withResponse', applied to the request after
-- 'setUserAgent' has set its @User-Agent@ header. The continuation is
-- passed through unchanged.
withResponse
  :: (MonadUnliftIO m, MonadIO n)
  => Request -> (Response (ConduitM i Strict.ByteString n ()) -> m a) -> m a
withResponse req inner =
  Network.HTTP.Simple.withResponse (setUserAgent req) inner

-- | Set the @Accept@ request header to @application/vnd.github.v3+json@.
--
-- Note that this does not touch the @User-Agent@ header; that is handled
-- separately by the request wrappers in this module.
setGithubHeaders :: Request -> Request
setGithubHeaders = setRequestHeader "Accept" ["application/vnd.github.v3+json"]

-- | Download the given URL to the given location, sending the Stack
-- @User-Agent@ header (see 'setUserAgent') with the request.
--
-- This delegates to 'Download.download': if the file already exists, no
-- download is performed. Otherwise the parent directory is created, the
-- content is downloaded to a temporary file, and on completion that file is
-- moved to the destination.
--
-- Throws an exception if things go wrong.
download :: HasTerm env
         => Request
         -> Path Abs File -- ^ destination
         -> RIO env Bool -- ^ Was a download performed (True) or did the file already exist (False)?
download = Download.download . setUserAgent

-- | Like 'download', but a file that is already present will be downloaded
-- a second time. Delegates to 'Download.redownload' after applying
-- 'setUserAgent' to the request.
--
-- Returns 'True' if the file was downloaded, 'False' otherwise.
redownload :: HasTerm env
           => Request
           -> Path Abs File -- ^ destination
           -> RIO env Bool
redownload = Download.redownload . setUserAgent

-- | Run 'Download.verifiedDownload' on the given 'DownloadRequest' after
-- its underlying request has been passed through 'setUserAgent'.
--
-- The wrapped function is a copied and extended version of
-- Network.HTTP.Download.download with the following additional features:
--
-- * Verifies that response content-length header (if present)
--     matches expected length
-- * Limits the download to (close to) the expected # of bytes
-- * Verifies that the expected # bytes were downloaded (not too few)
-- * Verifies md5 if response includes content-md5 header
-- * Verifies the expected hashes
--
-- Throws VerifiedDownloadException.
-- Throws IOExceptions related to file system operations.
-- Throws HttpException.
verifiedDownload
         :: HasTerm env
         => DownloadRequest
         -> Path Abs File -- ^ destination
         -> (Maybe Integer -> ConduitM ByteString Void (RIO env) ()) -- ^ custom hook to observe progress
         -> RIO env Bool -- ^ Whether a download was performed
verifiedDownload dr =
    Download.verifiedDownload (modifyRequest setUserAgent dr)

-- | 'verifiedDownload' with 'chattyDownloadProgress' installed as the
-- progress hook.
--
-- The 'Text' argument is the label and the 'Maybe Int' argument is the
-- total size; both are forwarded unchanged to 'chattyDownloadProgress'.
verifiedDownloadWithProgress
  :: HasTerm env
  => DownloadRequest
  -> Path Abs File
  -> Text
  -> Maybe Int
  -> RIO env Bool
verifiedDownloadWithProgress req destpath lbl msize =
    verifiedDownload req destpath progressSink
  where
    progressSink = chattyDownloadProgress lbl msize

-- | A conduit that consumes downloaded chunks and reports progress through
-- 'logSticky'.
--
-- It first logs @\<label\>: download has begun@. Chunk lengths are then
-- batched by @chunksOverTime 1@, and for every batch the running byte total
-- is logged. When a non-zero total size is supplied the message also shows
-- the total and a percentage; with 'Nothing' or @Just 0@ only the running
-- total is shown. The third argument is ignored.
chattyDownloadProgress
  :: ( HasLogFunc env
     , MonadIO m
     , MonadReader env m
     )
  => Text
  -> Maybe Int
  -> f
  -> ConduitT ByteString c m ()
chattyDownloadProgress label mtotalSize _ = do
    logSticky (RIO.display label <> ": download has begun")
    CL.map (Sum . Strict.length) .| chunksOverTime 1 .| report
  where
    -- The state is the number of bytes seen so far.
    report = evalStateC 0 $ awaitForever $ \(Sum batchSize) -> do
        modify (+ batchSize)
        soFar <- get
        logSticky (fromString (progressLine soFar))

    -- A zero total is treated like an unknown total, which also keeps the
    -- percentage computation away from a division by zero.
    progressLine :: Int -> String
    progressLine soFar =
        case mtotalSize of
            Just expected | expected /= 0 -> withTotal soFar expected
            _ -> withoutTotal soFar

    -- Example: ghc: 42.13 KiB downloaded...
    withoutTotal :: Int -> String
    withoutTotal soFar = printf template (T.unpack label)
      where
        template = "%s: " <> bytesfmt "%7.2f" soFar <> " downloaded..."

    -- Example: ghc: 50.00 MiB / 100.00 MiB (50.00%) downloaded...
    withTotal :: Int -> Int -> String
    withTotal soFar expected = printf template (T.unpack label) percent
      where
        template =
            "%s: "
                <> bytesfmt "%7.2f" soFar
                <> " / "
                <> bytesfmt "%.2f" expected
                <> " (%6.2f%%) downloaded..."
        percent :: Double
        percent = fromIntegral soFar / fromIntegral expected * 100

-- | Given a printf format string for the decimal part and a number of
-- bytes, formats the bytes using an appropriate binary unit (@B@, @KiB@,
-- ... up to @YiB@) and returns the formatted string.
--
-- The format string must be a complete floating-point conversion such as
-- @"%.2f"@. The value is divided by 1024 until it is below 1024 or the
-- largest unit has been reached. Negative inputs keep their sign.
--
-- >>> bytesfmt "%.2f" 512368
-- "500.36 KiB"
bytesfmt :: Integral a => String -> a -> String
bytesfmt formatter bs =
    printf (formatter <> " %s") (signum size * scaled) (bytesSuffixes !! unitIndex)
  where
    -- Convert to 'Double' before taking 'abs' / 'signum'. Doing it in the
    -- caller's integral type overflows for the minimum value of a bounded
    -- type (e.g. @abs minBound == minBound@ for 'Int'), which used to
    -- produce a positive number with the wrong unit.
    size :: Double
    size = fromIntegral bs

    -- 'unitIndex' never exceeds 'lastUnit', so the '!!' above is in range.
    (scaled, unitIndex) = until settled shrink (abs size, 0)

    settled :: (Double, Int) -> Bool
    settled (x, divisions) = x < 1024 || divisions == lastUnit

    shrink :: (Double, Int) -> (Double, Int)
    shrink (x, divisions) = (x / 1024, divisions + 1)

    lastUnit :: Int
    lastUnit = length bytesSuffixes - 1

    bytesSuffixes :: [String]
    bytesSuffixes = ["B","KiB","MiB","GiB","TiB","PiB","EiB","ZiB","YiB"]

-- Await eagerly (collecting inputs with monoidal append), but space out
-- yields by at least the given amount of time.
-- The final yield, emitted when upstream is exhausted, may come sooner and
-- may be a superfluous mempty.
-- Note that Integer and Float literals can be turned into NominalDiffTime
-- (these literals are interpreted as "seconds")
chunksOverTime :: (Monoid a, Semigroup a, MonadIO m) => NominalDiffTime -> ConduitM a a m ()
chunksOverTime diff = do
    startTime <- liftIO getCurrentTime
    evalStateC (startTime, mempty) loop
  where
    -- State is a tuple of:
    -- * the last time a yield happened (or the beginning of the sink)
    -- * the accumulated awaits since the last yield
    loop = do
      next <- await
      (lastYield, pending) <- get
      case next of
        -- Upstream is done: flush whatever has accumulated and stop.
        Nothing -> yield pending
        Just chunk -> do
          let pending' = pending <> chunk
          now <- liftIO getCurrentTime
          if diffUTCTime now lastYield > diff
            then put (now, mempty) >> yield pending'
            else put (lastYield, pending')
          loop