{-|

Copyright:

  This file is part of the package openid-connect.  It is subject to
  the license terms in the LICENSE file found in the top-level
  directory of this distribution and at:

    https://code.devalot.com/open/openid-connect

  No part of this package, including this file, may be copied,
  modified, propagated, or distributed except according to the terms
  contained in the LICENSE file.

License: BSD-2-Clause

Helpers for HTTPS.

-}
module OpenID.Connect.Client.HTTP
  ( HTTPS
  , uriToText
  , forceHTTPS
  , requestFromURI
  , addRequestHeader
  , jsonPostRequest
  , cacheUntil
  , parseResponse
  ) where

--------------------------------------------------------------------------------
-- Imports:
import Control.Applicative
import Data.Aeson (ToJSON, FromJSON, eitherDecode)
import qualified Data.Aeson as Aeson
import Data.Bifunctor (bimap)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy as LByteString
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.CaseInsensitive (CI)
import Data.Char (isDigit)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import Data.Time.Format (parseTimeM, defaultTimeLocale)
import qualified Network.HTTP.Client as HTTP
import qualified Network.HTTP.Types.Header as HTTP
import qualified Network.HTTP.Types.Status as HTTP
import Network.URI (URI(..), parseURI, uriToString)
import OpenID.Connect.JSON (ErrorResponse(..))

--------------------------------------------------------------------------------
-- | A function that can make HTTPS requests.
--
-- Make sure you are using a @Manager@ value from the
-- @http-client-tls@ package.  It's imperative that the requests
-- flowing through this function are encrypted.
--
-- The request builders in this module ('requestFromURI',
-- 'jsonPostRequest', 'addRequestHeader') do not install a response
-- status check of their own.  Non-2xx responses are expected to reach
-- 'parseResponse', which turns them into an 'ErrorResponse'.  An
-- implementation of this type may still throw (for example on
-- connection or TLS failures), so functions that take this 'HTTPS'
-- type should be called in an exception-safe way, and any exception
-- should be treated as an authentication failure.
--
-- @since 0.1.0.0
type HTTPS m = HTTP.Request -> m (HTTP.Response LByteString.ByteString)

--------------------------------------------------------------------------------
-- | Render a 'URI' as 'Text'.
--
-- Rendering goes through 'uriToString' with 'id' as the user-info
-- mapper, so the user-info component is emitted unchanged.
uriToText :: URI -> Text
uriToText = Text.pack . render
  where
    render :: URI -> String
    render u = uriToString id u ""

--------------------------------------------------------------------------------
-- | Rewrite a 'URI' so that its scheme is @https:@.
--
-- Only the scheme field is replaced; every other component of the
-- URI is left as it was.
forceHTTPS :: URI -> URI
forceHTTPS target = target { uriScheme = secureScheme }
  where
    secureScheme :: String
    secureScheme = "https:"

--------------------------------------------------------------------------------
-- | Convert a URI, or text to be parsed as one, into a request object.
--
-- The scheme is forced to @https:@ (see 'forceHTTPS') and an
-- @Accept: application/json@ header is set.  The result is 'Nothing'
-- when the text does not parse with 'parseURI', or when
-- 'HTTP.requestFromURI' rejects the URI.
requestFromURI :: Either Text URI -> Maybe HTTP.Request
requestFromURI = either fromText fromURI
  where
    fromText :: Text -> Maybe HTTP.Request
    fromText txt = parseURI (Text.unpack txt) >>= fromURI

    fromURI :: URI -> Maybe HTTP.Request
    fromURI uri =
      HTTP.requestFromURI (forceHTTPS uri)
        <&> addRequestHeader ("Accept", "application/json")

--------------------------------------------------------------------------------
-- | Turn a request into a @POST@ whose body is the JSON encoding of
-- the given value.
--
-- The @Content-Type@ header is set to @application/json@ through
-- 'addRequestHeader'.
jsonPostRequest :: ToJSON a => a -> HTTP.Request -> HTTP.Request
jsonPostRequest json req =
  let encoded = HTTP.RequestBodyLBS (Aeson.encode json)
      posted  = req { HTTP.method      = "POST"
                    , HTTP.requestBody = encoded
                    }
  in addRequestHeader ("Content-Type", "application/json") posted

--------------------------------------------------------------------------------
-- | Set a header on a request.
--
-- Existing headers whose name equals the new header's name (compared
-- as 'CI', i.e. case-insensitively) are removed.  The new header is
-- then placed at the front of the header list.
addRequestHeader :: (CI ByteString, ByteString) -> HTTP.Request -> HTTP.Request
addRequestHeader header req =
  req { HTTP.requestHeaders = header : kept }
  where
    kept = [ h | h <- HTTP.requestHeaders req, fst h /= fst header ]

--------------------------------------------------------------------------------
-- | Given a response, calculate how long it can be cached.
--
-- A @max-age@ value from the @Cache-Control@ header takes precedence.
-- It is added to the response's @Date@ header, so it is only used when
-- a parseable @Date@ is present.  Otherwise the @Expires@ header is
-- used.  The result is 'Nothing' when neither produces a time.
--
-- Both dates are parsed with the RFC 1123 format only.
cacheUntil :: HTTP.Response a -> Maybe UTCTime
cacheUntil res = maxAge <|> expires
  where
    parseTime :: ByteString -> Maybe UTCTime
    parseTime = parseTimeM True defaultTimeLocale rfc1123 . Char8.unpack

    rfc1123 :: String
    rfc1123 = "%a, %d %b %Y %X %Z"

    date :: Maybe UTCTime
    date = lookup HTTP.hDate (HTTP.responseHeaders res) >>= parseTime

    expires :: Maybe UTCTime
    expires = lookup HTTP.hExpires (HTTP.responseHeaders res) >>= parseTime

    -- Finds the text "max-age", skips any non-digit characters after
    -- it, and reads the following run of digits as seconds.
    maxAge :: Maybe UTCTime
    maxAge = do
      dt <- date
      bs <- lookup HTTP.hCacheControl (HTTP.responseHeaders res)
      ma <- nullM (snd (Char8.breakSubstring "max-age" bs))
      bn <- nullM (snd (Char8.break isDigit ma))
      seconds <- deltaSeconds (Char8.takeWhile isDigit bn)
      pure (addUTCTime (fromInteger seconds) dt)

    -- Interpret a run of decimal digits as delta-seconds.  RFC 9111
    -- says a value too large to represent is treated as 2^31.  The
    -- previous code read only the first 6 digits, which changed the
    -- magnitude of long values (e.g. one year became ~3.6 days).  More
    -- than 10 digits is always above 2^31, so such input is clamped
    -- without being parsed; this keeps the parsing work bounded.
    deltaSeconds :: ByteString -> Maybe Integer
    deltaSeconds digits
      | Char8.length digits > 10 = Just maxDeltaSeconds
      | otherwise = min maxDeltaSeconds . fst <$> Char8.readInteger digits

    maxDeltaSeconds :: Integer
    maxDeltaSeconds = 2 ^ (31 :: Int)

    nullM :: ByteString -> Maybe ByteString
    nullM bs = if Char8.null bs then Nothing else Just bs

--------------------------------------------------------------------------------
-- | Decode the JSON body of a response and calculate how long it can
-- be cached.
--
-- * For a 2xx response, the body is decoded as @a@ and paired with the
--   result of 'cacheUntil'.
--
-- * If that decoding fails, the body is tried as an 'ErrorResponse'.
--   If that also fails, the result is an 'ErrorResponse' carrying the
--   decoder's message plus the first 1024 bytes of the body.
--
-- * For any other status, the body is likewise tried as an
--   'ErrorResponse'.  The fallback message is
--   @"invalid response from server"@.
parseResponse
  :: FromJSON a
  => HTTP.Response LByteString.ByteString
  -> Either ErrorResponse (a, Maybe UTCTime)
parseResponse response =
  if HTTP.statusIsSuccessful (HTTP.responseStatus response)
    then eitherDecode (HTTP.responseBody response) &
         bimap asError (,cacheUntil response)
    else Left (asError "invalid response from server")
  where
    asError :: String -> ErrorResponse
    asError s = case eitherDecode (HTTP.responseBody response) of
      Left _  -> ErrorResponse (Text.pack s) (Just bodyError)
      Right e -> e

    -- The excerpt must be decoded leniently.  The body need not be
    -- valid UTF-8, and cutting it at 1024 bytes can split a multi-byte
    -- character; 'Text.decodeUtf8' would throw in either case.  This
    -- handler replaces each invalid byte with U+FFFD, which is what
    -- 'Data.Text.Encoding.Error.lenientDecode' does.
    bodyError :: Text
    bodyError = response
              & HTTP.responseBody
              & LChar8.take 1024
              & LChar8.toStrict
              & Text.decodeUtf8With (\_ _ -> Just '\xFFFD')