{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE QuasiQuotes           #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}

module Security.AccessTokenProvider.Internal.Providers.OAuth2.Ropcg
  ( probeProviderRopcg
  ) where

import           Control.Exception.Safe
import           Control.Lens
import           Control.Monad
import           Control.Monad.IO.Class
import           Control.Monad.IO.Unlift
import           Data.Aeson
import qualified Data.ByteString                                      as ByteString
import qualified Data.ByteString.Base64                               as B64
import           Data.Format
import qualified Data.Map                                             as Map
import           Data.Maybe
import           Data.Monoid
import qualified Data.Text                                            as Text
import qualified Data.Text.Encoding                                   as Text
import           Network.HTTP.Client
import           Network.HTTP.Client.TLS
import           Network.HTTP.Types
import qualified System.Environment                                   as Env
import           System.FilePath
import           System.Random
import           UnliftIO.Async
import           UnliftIO.Concurrent
import           UnliftIO.STM

import qualified Security.AccessTokenProvider.Internal.Lenses         as L
import           Security.AccessTokenProvider.Internal.Types
import qualified Security.AccessTokenProvider.Internal.Types.Severity as Severity
import           Security.AccessTokenProvider.Internal.Util

-- | Access Token Provider prober for access token retrieval via
-- OAuth2 Resource-Owner-Password-Credentials-Grant.
probeProviderRopcg :: (MonadMask m, MonadUnliftIO m) => AtpProbe m
probeProviderRopcg = AtpProbe probeProvider

-- | Probe whether the ROPCG provider can serve @tokenName@.
--
-- The provider is only considered when the environment variable
-- @ATP_CONF_ROPCG@ is present. Its value is UTF-8 encoded and decoded
-- with 'throwDecode' into an 'AtpPreconfRopcg'. That value is completed into
-- an 'AtpConfRopcg' by 'createRopcgConf' and finally handed to
-- 'tryCreateProvider'. Returns 'Nothing' when the variable is unset, or
-- whatever 'tryCreateProvider' returns otherwise.
probeProvider
  :: (MonadMask m, MonadUnliftIO m)
  => Backend m
  -> AccessTokenName
  -> m (Maybe (AccessTokenProvider m t))
probeProvider backend tokenName = do
  let BackendLog { .. } = backendLog backend
      BackendEnv { .. } = backendEnv backend
  logAddNamespace "probe-ropcg" $
    envLookup "ATP_CONF_ROPCG" >>= \ case
      Just confS -> do
        logMsg Severity.Info [fmt|Trying access token provider 'ropcg'|]
        throwDecode (Text.encodeUtf8 confS)
          >>= createRopcgConf
          >>= tryCreateProvider backend tokenName
      Nothing ->
        pure Nothing

-- | Derive an HTTP Basic @Authorization@ header from client credentials.
--
-- The client id and client secret are UTF-8 encoded, joined with @:@ and
-- base64-encoded.
--
-- NOTE(review): RFC 6749 §2.3.1 asks for client id and secret to be
-- form-urlencoded before being placed into Basic credentials; that is not
-- done here. It is left unchanged because authorization servers that do not
-- decode would break — confirm against the target server before changing.
makeBasicAuthorizationHeader :: ClientCredentials -> Header
makeBasicAuthorizationHeader credentials =
  let b64Token =
        [credentials^.L.clientId, credentials^.L.clientSecret]
          & map Text.encodeUtf8
          & ByteString.intercalate ":"
          & B64.encode
  in ("Authorization", "Basic " <> b64Token)

-- | Read @filename@ through the backend's filesystem interface and decode
-- its contents as JSON.
--
-- On a decoding failure the error is logged and an
-- 'AccessTokenProviderDeserialization' exception naming the file is thrown.
retrieveJson :: (FromJSON a, MonadCatch m) => Backend m -> FilePath -> m a
retrieveJson backend filename = do
  let fsBackend = backendFilesystem backend
      BackendLog { .. } = backendLog backend
  content <- fileRead fsBackend filename
  case eitherDecodeStrict content of
    Right a ->
      return a
    Left errMsgStr -> do
      let errMsg = Text.pack errMsgStr
      logMsg Severity.Error [fmt|JSON deserialization error: $errMsg|]
      throwM . AccessTokenProviderDeserialization $
        [fmt|Failed to deserialize '${filename}': $errMsg|]

-- | Retrieve user (resource owner) and client credentials from the
-- credentials directory.
--
-- The configured resource-owner and client password file names are resolved
-- against the configured credentials directory when they are relative. Absolute
-- file names are used unchanged. Both files are decoded with 'retrieveJson'.
retrieveCredentials :: MonadCatch m => Backend m -> AtpConfRopcg -> m Credentials
retrieveCredentials backend conf = do
  let baseDir = conf^.L.credentialsDirectory
  userCred <-
    retrieveJson backend (prefixIfRelative baseDir (conf^.L.resourceOwnerPasswordFile))
  clientCred <-
    retrieveJson backend (prefixIfRelative baseDir (conf^.L.clientPasswordFile))
  return Credentials { _user   = userCred
                     , _client = clientCred
                     }

  where
    -- Join relative file names onto the base directory with '</>'; the
    -- original applied the directory to the file name, which is ill-typed.
    prefixIfRelative baseDir filename =
      if isAbsolute filename
      then filename
      else baseDir </> filename

-- | Name of the environment variable expected to contain the path to the
-- mint credentials directory.
envCredentialsDirectory :: String
envCredentialsDirectory = "CREDENTIALS_DIR"

-- | Determine the credentials directory.
--
-- An explicitly configured directory takes precedence. Otherwise the value of
-- the 'envCredentialsDirectory' environment variable is used, falling back to
-- the current directory (@"."@) when that variable is unset.
retrieveCredentialsDir
  :: (MonadIO m, MonadThrow m)
  => AtpPreconfRopcg
  -> m FilePath
retrieveCredentialsDir envConf =
  maybe fromEnvironment pure (envConf^.L.credentialsDirectory)
  where
    fromEnvironment =
      liftIO (fromMaybe "." <$> Env.lookupEnv envCredentialsDirectory)

-- | Turn a parsed 'AtpPreconfRopcg' into a complete 'AtpConfRopcg'.
--
-- The steps are, in order:
--
-- * parse the auth endpoint;
-- * resolve the credentials directory;
-- * create a fresh TLS connection manager.
--
-- Missing optional settings fall back to @client.json@, @user.json@ and
-- 'defaultRefreshTimeFactor' respectively.
createRopcgConf
  :: (MonadIO m, MonadCatch m)
  => AtpPreconfRopcg
  -> m AtpConfRopcg
createRopcgConf envConf = do
  endpoint   <- parseEndpoint (envConf^.L.authEndpoint)
  credsDir   <- retrieveCredentialsDir envConf
  tlsManager <- liftIO (newManager tlsManagerSettings)
  pure AtpConfRopcg
    { _credentialsDirectory      = credsDir
    , _clientPasswordFile        = clientFile
    , _resourceOwnerPasswordFile = ownerFile
    , _refreshTimeFactor         = timeFactor
    , _authEndpoint              = endpoint
    , _manager                   = tlsManager
    , _tokens                    = envConf^.L.tokens
    }

  where
    -- Client credentials file, defaulting to "client.json".
    clientFile =
      fromMaybe "client.json" (envConf^.L.clientPasswordFile)
    -- Resource-owner (user) credentials file, defaulting to "user.json".
    ownerFile =
      fromMaybe "user.json" (envConf^.L.resourceOwnerPasswordFile)
    -- Fraction of "expires_in" after which the token is refreshed.
    timeFactor =
      fromMaybe defaultRefreshTimeFactor (envConf^.L.refreshTimeFactor)

-- | Main refreshing function.
-- Reads the user and client credentials. It then POSTs a form-urlencoded
-- password-grant request to the configured authorization endpoint. The
-- request carries @grant_type@, @username@, @password@ and the
-- space-separated @scope@ of the token definition, with the client
-- credentials in a Basic @Authorization@ header.
--
-- Non-200 responses are logged. They are thrown as
-- 'AccessTokenProviderRefreshFailure' when the body decodes as an OAuth2 error
-- object, and as 'AccessTokenProviderDeserialization' otherwise. A 200
-- response whose body does not decode as an 'AtpRopcgResponse' is also thrown
-- as 'AccessTokenProviderDeserialization'.
tryRefreshToken
  :: MonadCatch m
  => Backend m
  -> AtpConfRopcg
  -> AccessTokenName
  -> AtpRopcgTokenDef
  -> m AtpRopcgResponse
tryRefreshToken backend conf tokenName tokenDef =
  logAddNamespace "refreshActionOne" $ do
    credentials <- retrieveCredentials backend conf
    let httpBackend = backendHttp backend
        bodyParameters =
          [ ("grant_type", "password")
          , ("username", Text.encodeUtf8 (credentials^.L.user.L.applicationUsername))
          , ("password", Text.encodeUtf8 (credentials^.L.user.L.applicationPassword))
          , ("scope", packScopes (tokenDef^.L.scopes))
          ]
        authorization = makeBasicAuthorizationHeader (credentials^.L.client)
        httpRequest =
          (conf^.L.authEndpoint) { method = "POST"
                                 , requestHeaders = [authorization]
                                 }
          & urlEncodedBody bodyParameters
    -- NOTE(review): this logs the request via its Show instance. Confirm that
    -- the http-client version in use redacts the Authorization header there.
    logMsg Severity.Debug [fmt|HTTP Request for token refreshing: ${tshow httpRequest}|]
    response <- httpRequestExecute httpBackend httpRequest
    let status = responseStatus response
        body = responseBody response
    when (status /= ok200) $ do
      logMsg Severity.Error [fmt|Failed to refresh token: ${tshow response}|]
      throwM $ decodeOAuth2Error status body
    case eitherDecode body :: Either String AtpRopcgResponse of
      Right tokenResponse -> do
        logMsg Severity.Debug [fmt|Successfully refreshed token '${tokenName}'|]
        pure tokenResponse
      Left errMsgS -> do
        let errMsg = Text.pack errMsgS
        logMsg Severity.Error [fmt|Deserialization of token response failed: $errMsg|]
        throwM $ AccessTokenProviderDeserialization errMsg

  where
    packScopes = ByteString.intercalate " " . map Text.encodeUtf8

    BackendLog { .. } = backendLog backend

    -- Turn a non-200 response into the exception to throw. The message no
    -- longer carries a stray quote after the status.
    decodeOAuth2Error status body =
      case decode body of
        Just problem ->
          AccessTokenProviderRefreshFailure problem
        Nothing ->
          AccessTokenProviderDeserialization
            [fmt|Deserialization of OAuth2 error object failed; response status: ${tshow status}|]

-- | Background loop that keeps the token cache of one token populated.
--
-- Each iteration attempts a refresh and stores the outcome in @cache@: the
-- token, or the exception raised while refreshing. The outcome replaces any
-- previous value, so a failed refresh also replaces a previously cached token.
--
-- How long the loop then waits depends on the outcome:
--
-- * success with @expires_in@: 'L.refreshTimeFactor' times that value;
-- * success without @expires_in@: 'defaultRefreshInterval';
-- * failure: a random 1–10 seconds.
--
-- Every wait is clamped to at least 'minimumRefreshInterval', so that an
-- @expires_in@ of 0 or a non-positive refresh time factor cannot turn this
-- into a tight loop against the authorization endpoint.
tokenRefreshLoop
  :: forall m t . (MonadCatch m, MonadIO m)
  => Backend m
  -> AtpConfRopcg
  -> AccessTokenName
  -> AtpRopcgTokenDef
  -> TMVar (Either SomeException (AccessToken t))
  -> m ()
tokenRefreshLoop backend conf tokenName tokenDef cache = forever $ do
  eitherTokenResponse <- tryAny (tryRefreshToken backend conf tokenName tokenDef)
  atomically $ do
    let eitherToken = AccessToken . view L.accessToken <$> eitherTokenResponse
    isEmptyTMVar cache >>= \ case
      True -> putTMVar cache eitherToken
      False -> void $ swapTMVar cache eitherToken
  secondsToWait <-
    max minimumRefreshInterval <$> computeDurationToWait eitherTokenResponse
  let microsToWait = round $ secondsToWait * 10^(6 :: Int)
  threadDelay microsToWait

  where
    -- Returns duration in seconds.
    computeDurationToWait :: Either SomeException AtpRopcgResponse -> m Double
    computeDurationToWait eitherTokenResponse =
      case eitherTokenResponse of
        Right tokenResponse ->
          case tokenResponse^.L.expiresIn of
            Just expiresIn ->
              pure $ conf^.L.refreshTimeFactor * fromIntegral expiresIn
            Nothing ->
              pure defaultRefreshInterval
        Left exn -> do
          logMsg Severity.Error [fmt|Failed to refresh token '${tokenName}': $exn|]
          liftIO $ randomRIO (1, 10) -- Some jitter: wait 1 - 10 seconds.

    BackendLog { .. } = backendLog backend

-- | In seconds.
defaultRefreshInterval :: Double
defaultRefreshInterval = 60

-- | Lower bound, in seconds, for the wait between two refresh attempts.
minimumRefreshInterval :: Double
minimumRefreshInterval = 1

-- | By default, we start refreshing tokens after 80% of the
-- "expires_in" time of a token has been elapsed.
defaultRefreshTimeFactor :: Double
defaultRefreshTimeFactor = 0.8

-- | Create a provider for @tokenName@ if it is configured in 'L.tokens' of
-- the given configuration. Returns 'Nothing' otherwise.
tryCreateProvider
  :: (MonadUnliftIO m, MonadMask m)
  => Backend m
  -> AccessTokenName
  -> AtpConfRopcg
  -> m (Maybe (AccessTokenProvider m t))
tryCreateProvider backend tokenName conf = do
  let (AccessTokenName tokenNameText) = tokenName
      BackendLog { .. } = backendLog backend
      maybeTokenDef = Map.lookup tokenNameText (conf^.L.tokens)
  case maybeTokenDef of
    Just tokenDef -> do
      logMsg Severity.Info [fmt|AccessTokenProvider starting|]
      provider <- newProvider tokenDef
      pure (Just provider)
    Nothing ->
      pure Nothing

  where
    newProvider tokenDef = do
      (retrieveAction, releaseAction) <- newRetrieveAction backend conf tokenName tokenDef
      pure AccessTokenProvider
        { retrieveAccessToken = retrieveAction
        , releaseProvider = releaseAction
        }

-- | Start the background refresh loop for a single token and return a pair
-- of actions.
--
-- The first action reads the cached outcome. It blocks while the cache is
-- still empty, and rethrows a cached failure. The second action cancels the
-- refresh loop.
newRetrieveAction
  :: (MonadUnliftIO m, MonadCatch m)
  => Backend m
  -> AtpConfRopcg
  -> AccessTokenName
  -> AtpRopcgTokenDef
  -> m (m (AccessToken t), m ())
newRetrieveAction backend conf tokenName tokenDef = do
  cache <- atomically newEmptyTMVar
  asyncHandle <- async $ tokenRefreshLoop backend conf tokenName tokenDef cache
  -- Propagate exceptions escaping the refresh loop to the creating thread.
  link asyncHandle
  let retrieveAction =
        atomically (readTMVar cache) >>= \ case
          Right token -> pure token
          Left exn -> throwM exn
      releaseAction = cancel asyncHandle
  pure (retrieveAction, releaseAction)