diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index c0319c3e9..b5ad2c8c1 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -9,6 +9,7 @@ ### Bugs Fixed - Fixed warning for an unused function in curl.cpp when building the SDK using a version of libcurl older than 7.77.0. +- Invalidate the token cache within `BearerTokenAuthenticationPolicy` whenever a token request comes back with a 401 response. ### Other Changes diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index 713bcdd74..c7b54b1bb 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -576,6 +576,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { mutable Credentials::AccessToken m_accessToken; mutable std::shared_timed_mutex m_accessTokenMutex; mutable Credentials::TokenRequestContext m_accessTokenContext; + mutable std::atomic m_invalidateToken = {false}; public: /** @@ -610,6 +611,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { std::shared_lock readLock(other.m_accessTokenMutex); m_accessToken = other.m_accessToken; m_accessTokenContext = other.m_accessTokenContext; + m_invalidateToken.store(other.m_invalidateToken.load()); } void operator=(BearerTokenAuthenticationPolicy const&) = delete; diff --git a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp index 25d558486..aa2adad4b 100644 --- a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp +++ b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp @@ -31,6 +31,7 @@ std::unique_ptr BearerTokenAuthenticationPolicy::Send( auto result = AuthorizeAndSendRequest(request, nextPolicy, context); { auto const& response = *result; + m_invalidateToken = (response.GetStatusCode() == HttpStatusCode::Unauthorized); auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response); if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context)) { @@ -67,9 +68,10 @@ bool TokenNeedsRefresh( Azure::Core::Credentials::AccessToken const& cachedToken, Azure::Core::Credentials::TokenRequestContext const& cachedTokenRequestContext, Azure::DateTime const& currentTime, - Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext) + Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext, + bool forceRefresh) { - return newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId + return forceRefresh || newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId || newTokenRequestContext.Scopes != cachedTokenRequestContext.Scopes || currentTime > (cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration); } @@ -91,7 +93,12 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest( { std::shared_lock readLock(m_accessTokenMutex); - if (!TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext)) + if (!TokenNeedsRefresh( + m_accessToken, + m_accessTokenContext, + currentTime, + tokenRequestContext, + m_invalidateToken)) { ApplyBearerToken(request, m_accessToken); return; @@ -100,10 +107,20 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest( std::unique_lock writeLock(m_accessTokenMutex); // Check if token needs refresh for the second time in case another thread has just updated it. - if (TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext)) + if (TokenNeedsRefresh( + m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext, m_invalidateToken)) { - m_accessToken = m_credential->GetToken(tokenRequestContext, context); + TokenRequestContext trcCopy = tokenRequestContext; + if (m_invalidateToken) + { + // Need to set this to invalidate the credential's token cache to ensure we fetch a new token + // on subsequent GetToken calls. + trcCopy.MinimumExpiration = DateTime::duration::max(); + } + + m_accessToken = m_credential->GetToken(trcCopy, context); m_accessTokenContext = tokenRequestContext; + m_invalidateToken = false; } ApplyBearerToken(request, m_accessToken); diff --git a/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp b/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp index 6f09de0b6..fbe20ec82 100644 --- a/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp @@ -52,6 +52,37 @@ public: return std::make_unique(*this); } }; + +class TestTransportPolicyMultipleResponses final : public HttpPolicy { +private: + mutable int m_responsesCount = 0; + +public: + std::unique_ptr Send(Request&, NextHttpPolicy, Context const&) const override + { + if (m_responsesCount == 1) + { + m_responsesCount++; + return std::make_unique(1, 1, HttpStatusCode::Unauthorized, "TestStatus"); + } + if (m_responsesCount == 2) + { + m_responsesCount++; + return std::make_unique(1, 1, HttpStatusCode::Ok, "TestStatus"); + } + if (m_responsesCount > 2) + { + EXPECT_TRUE(false); + } + m_responsesCount++; + return std::make_unique(1, 1, HttpStatusCode::Ok, "TestStatus"); + } + + std::unique_ptr Clone() const override + { + return std::make_unique(*this); + } +}; } // namespace TEST(BearerTokenAuthenticationPolicy, InitialGet) @@ -169,6 +200,66 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry) } } +TEST(BearerTokenAuthenticationPolicy, TokenInvalidatedAfterUnauth) +{ + using namespace std::chrono_literals; + auto accessToken = std::make_shared(); + + std::vector> policies; + + TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; + + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); + + policies.emplace_back(std::make_unique()); + + HttpPipeline pipeline(policies); + + // The first request is successful, the token gets cached in the credential + { + Request request(HttpMethod::Get, Url("https://www.azure.com")); + + *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h}; + + pipeline.Send(request, Context()); + + { + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1"); + } + } + + // The second request returns unauthorized, the token should be invalidated + { + Request request(HttpMethod::Get, Url("https://www.azure.com")); + + *accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h}; + + pipeline.Send(request, Context()); + + { + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1"); + } + + // We expect the next call to return a new token + pipeline.Send(request, Context()); + + { + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2"); + } + } +} + TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry) { using namespace std::chrono_literals;