-- |
-- Authorization grant flow implementation. You probably want
-- 'Network.OAuth2.JWT.Client'.
--
{-# LANGUAGE OverloadedStrings #-}
module Network.OAuth2.JWT.Client.AuthorizationGrant (
    GrantError (..)
  , sign
  , refresh
  , local
  , grant
  ) where

import qualified Control.Concurrent.MVar as MVar
import           Control.Lens ((.~), (&))
import           Control.Monad.IO.Class (MonadIO (..))
import           Control.Monad.Trans.Bifunctor (BifunctorTrans (..))
import           Control.Monad.Trans.Except (ExceptT (..), runExceptT)

import           Crypto.JWT (JWK, JWTError)
import qualified Crypto.JWT as JWT

import qualified Data.Aeson as Aeson
import           Data.Bifunctor as X (Bifunctor(..))
import qualified Data.ByteString.Lazy as LazyByteString
import qualified Data.HashMap.Strict as HashMap
import           Data.String (IsString (..))
import           Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.Text.Encoding.Error as Text
import           Data.Time (UTCTime)
import qualified Data.Time as Time

import           Network.OAuth2.JWT.Client.Data
import qualified Network.OAuth2.JWT.Client.Serial as Serial
import qualified Network.HTTP.Client as HTTP
import qualified Network.HTTP.Types as HTTP


-- |
-- The ways in which obtaining an access token can fail.
--
data GrantError =
    SerialisationGrantError Text
    -- ^ The token endpoint answered with status 200, but its body could
    -- not be decoded by 'Serial.response'.
  | JWTGrantError JWT.JWTError
    -- ^ Signing the assertion JWT failed.
  | EndpointGrantError Text
    -- ^ The token endpoint could not be parsed into an HTTP request.
  | StatusGrantError Int Text
    -- ^ The token endpoint answered with a non-200 status. Carries the
    -- status code and the response body, decoded as UTF-8 with invalid
    -- bytes replaced rather than rejected.
    deriving (Eq, Show)

-- |
-- Obtain an access token. If we have already acquired one (and it
-- is still valid) we re-use that token; if we don't already have a
-- token, or the token has expired, we go and ask for a new one.
--
-- This operation is safe to call from multiple threads. If we are
-- using a current token, reads happen concurrently. If we have to
-- go to the network, the request is serialised so that only one
-- request is made for a new token.
--
-- If 'refresh' throws an IO exception (it does not catch exceptions
-- from 'HTTP.httpLbs'), 'MVar.modifyMVar' restores the previous token
-- state and the exception propagates to the caller.
--
grant :: Store -> IO (Either GrantError AccessToken)
grant (Store manager endpoint claims jwk store) = do
  -- Sample the clock only after reading the state: 'MVar.readMVar'
  -- blocks while another thread is refreshing, and a time taken
  -- before that wait would be stale.
  cached <- MVar.readMVar store
  now <- Time.getCurrentTime
  case local now cached of
    Just token ->
      pure . Right $ token
    Nothing ->
      MVar.modifyMVar store $ \state -> do
        -- Re-check under the lock, with a fresh clock reading: another
        -- thread may have refreshed the token while we were waiting.
        locked <- Time.getCurrentTime
        case local locked state of
          Just token ->
            pure (state, Right token)
          Nothing -> do
            result <- runExceptT (refresh locked manager endpoint claims jwk)
            pure $ case result of
              Left err ->
                -- Keep whatever state we had; the caller sees the error.
                (state, Left err)
              Right (Response token expiry) ->
                -- Expiry is measured from before the request was sent,
                -- so the cached token errs on the side of expiring early.
                ( HasToken token (Time.addUTCTime (getExpiresIn expiry) locked)
                , Right token
                )

-- |
-- Obtain an already acquired access token iff it is still valid at
-- the given time, i.e. strictly before its recorded expiry.
--
local :: UTCTime -> TokenState -> Maybe AccessToken
local now state =
  case state of
    HasToken token time | now < time ->
      Just token
    HasToken _ _ ->
      Nothing
    NoToken ->
      Nothing

-- |
-- Request a new access token as per the specified claims.
--
-- This request is defined in
-- <https://tools.ietf.org/html/rfc7523#section-2.1 section 2.1> of
-- <https://tools.ietf.org/html/rfc7523 RFC 7523>.
--
-- A signed assertion is POSTed (form-encoded, @Accept: application/json@)
-- to the token endpoint. A 200 response is decoded with
-- 'Serial.response'; any other status becomes a 'StatusGrantError'.
-- Exceptions thrown by 'HTTP.httpLbs' (e.g. connection failures) are
-- not caught here.
--
refresh ::
     UTCTime
  -> HTTP.Manager
  -> TokenEndpoint
  -> Claims
  -> JWK
  -> ExceptT GrantError IO Response
refresh now manager endpoint claims jwk = do
  assertion <- firstT JWTGrantError $
    sign now claims jwk
  req <- ExceptT . pure . first (EndpointGrantError . Text.pack . show) $
    HTTP.parseRequest (Text.unpack . getTokenEndpoint $ endpoint)
  res <- liftIO . flip HTTP.httpLbs manager $
    HTTP.urlEncodedBody [
        ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
      , ("assertion", getAssertion assertion)
      ] $ req {
          HTTP.requestHeaders = [
              ("Accept", "application/json")
            ]
        }
  case HTTP.statusCode . HTTP.responseStatus $ res of
    200 ->
      ExceptT . pure . first SerialisationGrantError $
        Serial.response (HTTP.responseBody res)
    status ->
      ExceptT . pure . Left $
        StatusGrantError status (decodeBody . HTTP.responseBody $ res)
  where
    -- Error bodies come from an arbitrary server and need not be valid
    -- UTF-8; replace bad bytes rather than throwing from pure code.
    decodeBody =
      Text.decodeUtf8With Text.lenientDecode . LazyByteString.toStrict

-- |
-- Sign a JWT with the specified claims and key.
--
-- The format of the JWT is defined by
-- <https://tools.ietf.org/html/rfc7519 RFC 7519>.
--
-- The specifics of the claims are defined by the OAuth2 JWT Profile,
-- <https://tools.ietf.org/html/rfc7523#section-3 RFC 7523, section 3>.
--
-- The token is always signed with RS256. The issued-at claim is @now@,
-- and the expiry is @now@ plus the configured 'ExpiresIn'. Scopes are
-- joined with single spaces into a @scope@ claim. Custom claims are
-- added after it, so a custom @scope@ entry replaces the generated one.
--
sign :: UTCTime -> Claims -> JWK -> ExceptT JWTError IO Assertion
sign now (Claims issuer subject audience scopes expires custom) jwk = do
  let
    format =
      fromString . Text.unpack
    header =
      JWT.newJWSHeader ((), JWT.RS256)
        & JWT.typ .~ Just (JWT.HeaderParam () "JWT")
    claims =
      JWT.emptyClaimsSet
        & JWT.claimIss .~ Just (format . getIssuer $ issuer)
        & JWT.claimSub .~ fmap (format . getSubject) subject
        & JWT.claimAud .~ Just (JWT.Audience [format . getAudience $ audience])
        & JWT.claimIat .~ Just (JWT.NumericDate now)
        & JWT.claimExp .~ Just (JWT.NumericDate $ Time.addUTCTime (getExpiresIn expires) now)
        & JWT.unregisteredClaims .~ (HashMap.fromList $
            [ ("scope", Aeson.toJSON . Text.intercalate " " $ getScope <$> scopes) ]
              ++ custom)
  signed <- JWT.signClaims jwk header claims
  pure . Assertion . LazyByteString.toStrict . JWT.encodeCompact $ signed