{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE CPP #-}

-- | Servant client authentication.
module Servant.Auth.Hmac.Client (
    -- * HMAC client settings
    HmacSettings (..),
    defaultHmacSettings,

    -- * HMAC servant client
    HmacClientM (..),
    runHmacClient,
    hmacClient,
) where

import Control.Monad ((>=>))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Reader (MonadReader (..), ReaderT, asks, runReaderT)
import Control.Monad.Trans.Class (lift)
import Data.ByteString (ByteString)
import Data.CaseInsensitive (mk)
import Data.Foldable (toList)
import Data.List (sort)
import Data.Proxy (Proxy (..))
import Data.Sequence (fromList, (<|))
import Data.String (fromString)
import Servant.Client (
    BaseUrl,
    Client,
    ClientEnv (baseUrl),
    ClientError,
    ClientM,
    HasClient,
    runClientM,
 )
import Servant.Client.Core (RunClient (..), clientIn)
import Servant.Client.Internal.HttpClient (defaultMakeClientRequest)

import Servant.Auth.Hmac.Crypto (
    RequestPayload (..),
    SecretKey,
    Signature (..),
    authHeaderName,
    keepWhitelistedHeaders,
    requestSignature,
    signSHA256,
 )

import qualified Network.HTTP.Client as Client
import qualified Servant.Client.Core as Servant

{- | Environment for 'HmacClientM'. Bundles everything the client needs in
order to sign outgoing requests.
-}
data HmacSettings = HmacSettings
    { hmacSigner :: SecretKey -> ByteString -> Signature
    -- ^ Signing function that will sign all outgoing requests.
    , hmacSecretKey :: SecretKey
    -- ^ Secret key handed to the signing function.
    , hmacRequestHook :: Maybe (Servant.Request -> ClientM ())
    -- ^ Optional action that is run for every request after the request
    -- has been signed. Useful for debugging.
    }

{- | Build 'HmacSettings' from a secret key, using these defaults:

1. 'hmacSigner' is 'signSHA256'.
2. 'hmacSecretKey' is the key passed as the argument.
3. 'hmacRequestHook' is 'Nothing', i.e. no hook is run.
-}
defaultHmacSettings :: SecretKey -> HmacSettings
defaultHmacSettings secretKey =
    HmacSettings
        { hmacRequestHook = Nothing
        , hmacSecretKey = secretKey
        , hmacSigner = signSHA256
        }

{- | @newtype@ over 'ClientM' with 'HmacSettings' available through a
'ReaderT' layer. Requests run in this monad are signed automatically (see
the 'RunClient' instance for 'HmacClientM').
-}
newtype HmacClientM a = HmacClientM
    { runHmacClientM :: ReaderT HmacSettings ClientM a
    -- ^ Unwrap to the underlying reader over 'ClientM'.
    }
    deriving
        ( Functor
        , Applicative
        , Monad
        , MonadIO
        , MonadReader HmacSettings
        )

-- | Lift a plain 'ClientM' action into 'HmacClientM'; the settings are ignored.
hmacifyClient :: ClientM a -> HmacClientM a
hmacifyClient action = HmacClientM (lift action)

{- | Sign a request using the 'HmacSettings' from the environment and the
base url of the surrounding 'ClientEnv'. If 'hmacRequestHook' is set, it is
called with the signed request before that request is returned.
-}
hmacClientSign :: Servant.Request -> HmacClientM Servant.Request
hmacClientSign req = do
    settings <- ask
    -- The base url lives in the ClientM environment, not in HmacSettings.
    url <- hmacifyClient (asks baseUrl)
    signedRequest <-
        liftIO $
            signRequestHmac (hmacSigner settings) (hmacSecretKey settings) url req
    -- Runs the hook when present; does nothing for 'Nothing'.
    mapM_ (\hook -> hmacifyClient (hook signedRequest)) (hmacRequestHook settings)
    pure signedRequest

{- | Requests are first signed with 'hmacClientSign' and then executed by
the underlying 'ClientM' implementation; errors are thrown in 'ClientM'.
-}
instance RunClient HmacClientM where
    runRequestAcceptStatus acceptable = hmacClientSign >=> sendSigned
      where
        -- Delegate the already signed request to the 'ClientM' instance.
        sendSigned = hmacifyClient . runRequestAcceptStatus acceptable

    throwClientError :: ClientError -> HmacClientM a
    throwClientError err = hmacifyClient (throwClientError err)

{- | Run an 'HmacClientM' action: supply the 'HmacSettings' to the reader
layer and execute the resulting 'ClientM' action with the given 'ClientEnv'.
-}
runHmacClient ::
    HmacSettings ->
    ClientEnv ->
    HmacClientM a ->
    IO (Either ClientError a)
runHmacClient settings env client = runClientM plainAction env
  where
    -- The action with its reader layer discharged.
    plainAction = runReaderT (runHmacClientM client) settings

{- | Generates a set of client functions for an API, running in
'HmacClientM'. The @api@ type is passed via a type application.
-}
hmacClient :: forall api. HasClient HmacClientM api => Client HmacClientM api
hmacClient = clientIn (Proxy @api) (Proxy @HmacClientM)

----------------------------------------------------------------------------
-- Internals
----------------------------------------------------------------------------

{- | Convert a servant request into the 'RequestPayload' that gets signed.

* The query string parameters are sorted before the request is converted
  with 'defaultMakeClientRequest'.
* The host part is taken from an explicit @Host@ header if the converted
  request has one; otherwise it is the request host, followed by
  @:\<port\>@ unless the port is 443 for a secure request or 80 for an
  insecure one.
* @Host@ and @Accept-Encoding: gzip@ are added to the header list only when
  the request does not already carry a header of that name, so that the
  signed headers do not contain entries that differ from what is sent.
* The header list is filtered with 'keepWhitelistedHeaders'.
* NOTE: 'rpContent' is always empty, i.e. the request body is not part of
  the payload.
-}
servantRequestToPayload :: BaseUrl -> Servant.Request -> IO RequestPayload
servantRequestToPayload url sreq = do
    let sortedReq =
            sreq
                { Servant.requestQueryString =
                    fromList $ sort $ toList $ Servant.requestQueryString sreq
                }
#if MIN_VERSION_servant_client(0,20,0)
    -- servant-client 0.20: defaultMakeClientRequest :: BaseUrl -> Request -> IO Request
    req <- defaultMakeClientRequest url sortedReq
#else
    -- servant-client 0.19: defaultMakeClientRequest :: BaseUrl -> Request -> Request
    let req = defaultMakeClientRequest url sortedReq
#endif

    let
        explicitHeaders = Client.requestHeaders req

        -- Default ports are left out of the host part.
        isDefaultPort =
            (Client.secure req, Client.port req) `elem` [(True, 443), (False, 80)]

        hostAndPort :: ByteString
        hostAndPort = case lookup (mk "Host") explicitHeaders of
            Just hp -> hp
            Nothing
                | isDefaultPort -> Client.host req
                | otherwise ->
                    Client.host req <> ":" <> fromString (show (Client.port req))

        -- Prepend a default header unless one with that name already exists;
        -- http-client only fills in Host / Accept-Encoding when they are absent.
        addIfMissing header@(name, _) headers
            | any ((== name) . fst) headers = headers
            | otherwise = header : headers

    return
        RequestPayload
            { rpMethod = Client.method req
            , rpContent = "" -- toBsBody $ Client.requestBody req
            , rpHeaders =
                keepWhitelistedHeaders $
                    addIfMissing ("Host", hostAndPort) $
                        addIfMissing ("Accept-Encoding", "gzip") explicitHeaders
            , rpRawUrl = hostAndPort <> Client.path req <> Client.queryString req
            }

--    toBsBody :: RequestBody -> ByteString
--    toBsBody (RequestBodyBS bs)       = bs
--    toBsBody (RequestBodyLBS bs)      = LBS.toStrict bs
--    toBsBody (RequestBodyBuilder _ b) = LBS.toStrict $ toLazyByteString b
--    toBsBody _                        = ""  -- heh

{- | Compute the signature of a request and prepend it as a header named
'authHeaderName', with the value

@
HMAC \<signature\>
@

The payload that is signed is produced by 'servantRequestToPayload'.
-}
signRequestHmac ::
    -- | Signing function
    (SecretKey -> ByteString -> Signature) ->
    -- | Secret key used for signing the request
    SecretKey ->
    -- | Base url for servant request
    BaseUrl ->
    -- | Original request
    Servant.Request ->
    -- | Signed request
    IO Servant.Request
signRequestHmac signer sk url req =
    attachSignature <$> servantRequestToPayload url req
  where
    -- Put the authentication header in front of the existing headers.
    attachSignature payload =
        let signature = requestSignature signer sk payload
            authHead = (authHeaderName, "HMAC " <> unSignature signature)
         in req{Servant.requestHeaders = authHead <| Servant.requestHeaders req}