{-|

Copyright:

  This file is part of the package openid-connect.  It is subject to
  the license terms in the LICENSE file found in the top-level
  directory of this distribution and at:

    https://code.devalot.com/open/openid-connect

  No part of this package, including this file, may be copied,
  modified, propagated, or distributed except according to the terms
  contained in the LICENSE file.

License: BSD-2-Clause

Helpers for HTTPS.

-}
module OpenID.Connect.Client.HTTP
  ( HTTPS
  , uriToText
  , forceHTTPS
  , requestFromURI
  , addRequestHeader
  , jsonPostRequest
  , cacheUntil
  , parseResponse
  ) where

--------------------------------------------------------------------------------
-- Imports:
import Control.Applicative
import Data.Aeson (ToJSON, FromJSON, eitherDecode)
import qualified Data.Aeson as Aeson
import Data.Bifunctor (bimap)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy as LByteString
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.CaseInsensitive (CI)
import Data.Char (isDigit)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import Data.Time.Format (parseTimeM, defaultTimeLocale)
import qualified Network.HTTP.Client as HTTP
import qualified Network.HTTP.Types.Header as HTTP
import qualified Network.HTTP.Types.Status as HTTP
import Network.URI (URI(..), parseURI, uriToString)
import OpenID.Connect.JSON (ErrorResponse(..))

--------------------------------------------------------------------------------
-- NOTE(review): 'requestFromURI' in this module does not visibly set
-- @checkResponse@ on the requests it builds, and 'parseResponse' has
-- its own branch for non-2xx statuses.  The paragraph below claiming
-- that non-2xx responses throw should be confirmed against
-- http-client's defaults.
-- | A function that can make HTTPS requests: it is given an
-- 'HTTP.Request' and produces the response, with a lazy 'ByteString'
-- body, in the monad @m@.
--
-- Make sure you are using a @Manager@ value from the
-- @http-client-tls@ package.  It's imperative that the requests
-- flowing through this function are encrypted.
--
-- All requests are set to throw an exception if the response status
-- code is not in the 2xx range.  Therefore, functions that take this
-- 'HTTPS' type should be called in an exception-safe way and any
-- exception should be treated as an authentication failure.
--
-- @since 0.1.0.0
type HTTPS m = HTTP.Request -> m (HTTP.Response LByteString.ByteString)

--------------------------------------------------------------------------------
-- | Render a 'URI' as 'Text'.
--
-- 'uriToString' is given 'id' as its user-info mapping, so any
-- user-info component is rendered unchanged.
uriToText :: URI -> Text
uriToText = Text.pack . flip (uriToString id) ""

--------------------------------------------------------------------------------
-- | Force the given URI to use HTTPS by overwriting its scheme with
-- @https:@.  Every other component of the URI is left untouched.
forceHTTPS :: URI -> URI
forceHTTPS original = original { uriScheme = https }
  where
    https :: String
    https = "https:"

--------------------------------------------------------------------------------
-- | Convert a URI or Text value into a pre-configured request object.
--
-- A 'Text' argument is first parsed with 'parseURI'; 'Nothing' is
-- returned when that fails.  The URI's scheme is then forced to
-- HTTPS (see 'forceHTTPS'), and the resulting request carries an
-- @Accept: application/json@ header.  'Nothing' is also returned if
-- http-client cannot build a request from the URI.
requestFromURI :: Either Text URI -> Maybe HTTP.Request
requestFromURI input = do
  target <- either (parseURI . Text.unpack) Just input
  HTTP.requestFromURI (forceHTTPS target)
    <&> addRequestHeader ("Accept", "application/json")

--------------------------------------------------------------------------------
-- | Add a JSON body to a request.
--
-- The request's method becomes @POST@, its body becomes the JSON
-- encoding of the given value, and its @Content-Type@ header is set
-- to @application/json@ (replacing any existing @Content-Type@ via
-- 'addRequestHeader').
jsonPostRequest :: ToJSON a => a -> HTTP.Request -> HTTP.Request
jsonPostRequest json req =
    addRequestHeader ("Content-Type", "application/json") postRequest
  where
    postRequest =
      req
        { HTTP.method = "POST"
        , HTTP.requestBody = HTTP.RequestBodyLBS (Aeson.encode json)
        }

--------------------------------------------------------------------------------
-- | Add a header to the request.
--
-- Any headers already on the request with the same name are removed
-- first.  Header names are 'CI' values, so that comparison ignores
-- case.  The new header is placed at the front of the header list,
-- and the remaining headers keep their relative order.
addRequestHeader :: (CI ByteString, ByteString) -> HTTP.Request -> HTTP.Request
addRequestHeader header@(name, _) req =
    req { HTTP.requestHeaders = header : others }
  where
    others = [ h | h@(n, _) <- HTTP.requestHeaders req, n /= name ]

--------------------------------------------------------------------------------
-- | Given a response, calculate how long it can be cached.
--
-- The @max-age@ directive of the @Cache-Control@ header is preferred.
-- It is added to the time in the response's @Date@ header, so both
-- headers must be present.  Otherwise the @Expires@ header is used.
-- @Date@ and @Expires@ are parsed with the RFC 1123 layout
-- (e.g. @Sun, 06 Nov 1994 08:49:37 GMT@).  Returns 'Nothing' when
-- neither source yields a time.
--
-- A @max-age@ value with more than six significant digits is clamped
-- to 999999 seconds instead of being misread.
cacheUntil :: HTTP.Response a -> Maybe UTCTime
cacheUntil res = maxAge <|> expires
  where
    headers :: HTTP.ResponseHeaders
    headers = HTTP.responseHeaders res

    parseTime :: ByteString -> Maybe UTCTime
    parseTime = parseTimeM True defaultTimeLocale rfc1123 . Char8.unpack

    rfc1123 :: String
    rfc1123 = "%a, %d %b %Y %X %Z"

    date :: Maybe UTCTime
    date = lookup HTTP.hDate headers >>= parseTime

    expires :: Maybe UTCTime
    expires = lookup HTTP.hExpires headers >>= parseTime

    -- Date + max-age.  The directive name is matched case-sensitively,
    -- and its value is the first run of digits found after it.
    maxAge :: Maybe UTCTime
    maxAge = do
      dt <- date
      bs <- lookup HTTP.hCacheControl headers
      ma <- nullM (snd (Char8.breakSubstring "max-age" bs))
      bn <- nullM (snd (Char8.break isDigit ma))
      let secs = deltaSeconds (Char8.takeWhile isDigit bn)
      pure (addUTCTime (fromIntegral secs) dt)

    -- Convert a run of ASCII digits to seconds.  The previous code
    -- truncated the input to six characters, which silently turned
    -- e.g. 31536000 into 315360.  Longer values are now clamped
    -- instead.  The length limit still keeps readInt from overflowing.
    deltaSeconds :: ByteString -> Int
    deltaSeconds digits
      | Char8.length significant > maxDigits = maxDeltaSeconds
      | otherwise = maybe 0 fst (Char8.readInt significant)
      where
        -- Leading zeros don't count towards the length limit.  An
        -- all-zero value leaves an empty string, meaning zero seconds.
        significant = Char8.dropWhile (== '0') digits

    -- Keeps the original upper bound: the largest value six digits
    -- can express.
    maxDigits :: Int
    maxDigits = 6

    maxDeltaSeconds :: Int
    maxDeltaSeconds = 999999

    nullM :: ByteString -> Maybe ByteString
    nullM bs = if Char8.null bs then Nothing else Just bs

--------------------------------------------------------------------------------
-- | Decode the JSON body of a response and calculate how long it can
-- be cached.
--
-- A response with a 2xx status has its body decoded as @a@, paired
-- with the result of 'cacheUntil'.  For any other status, or when
-- decoding fails, the body is decoded as an 'ErrorResponse' if
-- possible.  Failing that, an 'ErrorResponse' is built from the
-- failure message and (at most) the first 1024 bytes of the body.
parseResponse
  :: FromJSON a
  => HTTP.Response LByteString.ByteString
  -> Either ErrorResponse (a, Maybe UTCTime)
parseResponse response
  | HTTP.statusIsSuccessful (HTTP.responseStatus response) =
      eitherDecode body & bimap asError (, cacheUntil response)
  | otherwise =
      Left (asError "invalid response from server")
  where
    body :: LByteString.ByteString
    body = HTTP.responseBody response

    asError :: String -> ErrorResponse
    asError s = case eitherDecode body of
      Left _  -> ErrorResponse (Text.pack s) (Just bodyError)
      Right e -> e

    -- The body comes from a remote server and may not be UTF-8.
    -- Cutting it at a fixed byte count can also split a multi-byte
    -- sequence.  'Text.decodeUtf8' would throw on either case, so
    -- decode leniently and substitute U+FFFD for invalid bytes.
    bodyError :: Text
    bodyError = body
              & LChar8.take 1024
              & LChar8.toStrict
              & Text.decodeUtf8With (\_ _ -> Just '\xFFFD')