Set token to be expired if response comes back as unauthorized, within BeareTokenAuthenticationPolicy. (#6151)

* Set token to be expired if response comes back as unauthorized.

* Add CL entry.

* Update CL.

* Use trc MinimumExpiration to invalidate the credential's token cache.

* Add test.

* Address PR feedback.

* Remove comment as it is no longer relevant.

* Use initializer list syntax to see if posix compilers are okay with that.

* Keep the bool field as non-atomic.

* Revert "Keep the bool field as non-atomic."

This reverts commit 1b8c7622d5234b010bb0a4eb5db8a436de5a2adf.
This commit is contained in:
Ahson Khan 2024-10-30 14:00:42 -07:00 committed by GitHub
parent c168d736dd
commit 064fcad72f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 116 additions and 5 deletions

View File

@ -9,6 +9,7 @@
### Bugs Fixed
- Fixed warning for an unused function in curl.cpp when building the SDK using a version of libcurl older than 7.77.0.
- Invalidate the token cache within `BearerTokenAuthenticationPolicy` whenever a token request comes back with a 401 response.
### Other Changes

View File

@ -576,6 +576,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
mutable Credentials::AccessToken m_accessToken;
mutable std::shared_timed_mutex m_accessTokenMutex;
mutable Credentials::TokenRequestContext m_accessTokenContext;
mutable std::atomic<bool> m_invalidateToken = {false};
public:
/**
@ -610,6 +611,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
std::shared_lock<std::shared_timed_mutex> readLock(other.m_accessTokenMutex);
m_accessToken = other.m_accessToken;
m_accessTokenContext = other.m_accessTokenContext;
m_invalidateToken.store(other.m_invalidateToken.load());
}
void operator=(BearerTokenAuthenticationPolicy const&) = delete;

View File

@ -31,6 +31,7 @@ std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
auto result = AuthorizeAndSendRequest(request, nextPolicy, context);
{
auto const& response = *result;
m_invalidateToken = (response.GetStatusCode() == HttpStatusCode::Unauthorized);
auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response);
if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context))
{
@ -67,9 +68,10 @@ bool TokenNeedsRefresh(
Azure::Core::Credentials::AccessToken const& cachedToken,
Azure::Core::Credentials::TokenRequestContext const& cachedTokenRequestContext,
Azure::DateTime const& currentTime,
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext)
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext,
bool forceRefresh)
{
return newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
return forceRefresh || newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
|| newTokenRequestContext.Scopes != cachedTokenRequestContext.Scopes
|| currentTime > (cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration);
}
@ -91,7 +93,12 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
{
std::shared_lock<std::shared_timed_mutex> readLock(m_accessTokenMutex);
if (!TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
if (!TokenNeedsRefresh(
m_accessToken,
m_accessTokenContext,
currentTime,
tokenRequestContext,
m_invalidateToken))
{
ApplyBearerToken(request, m_accessToken);
return;
@ -100,10 +107,20 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
std::unique_lock<std::shared_timed_mutex> writeLock(m_accessTokenMutex);
// Check if token needs refresh for the second time in case another thread has just updated it.
if (TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
if (TokenNeedsRefresh(
m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext, m_invalidateToken))
{
m_accessToken = m_credential->GetToken(tokenRequestContext, context);
TokenRequestContext trcCopy = tokenRequestContext;
if (m_invalidateToken)
{
// Need to set this to invalidate the credential's token cache to ensure we fetch a new token
// on subsequent GetToken calls.
trcCopy.MinimumExpiration = DateTime::duration::max();
}
m_accessToken = m_credential->GetToken(trcCopy, context);
m_accessTokenContext = tokenRequestContext;
m_invalidateToken = false;
}
ApplyBearerToken(request, m_accessToken);

View File

@ -52,6 +52,37 @@ public:
return std::make_unique<TestTransportPolicy>(*this);
}
};
class TestTransportPolicyMultipleResponses final : public HttpPolicy {
private:
mutable int m_responsesCount = 0;
public:
std::unique_ptr<RawResponse> Send(Request&, NextHttpPolicy, Context const&) const override
{
if (m_responsesCount == 1)
{
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Unauthorized, "TestStatus");
}
if (m_responsesCount == 2)
{
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
}
if (m_responsesCount > 2)
{
EXPECT_TRUE(false);
}
m_responsesCount++;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
}
std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<TestTransportPolicyMultipleResponses>(*this);
}
};
} // namespace
TEST(BearerTokenAuthenticationPolicy, InitialGet)
@ -169,6 +200,66 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry)
}
}
TEST(BearerTokenAuthenticationPolicy, TokenInvalidatedAfterUnauth)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<HttpPolicy>> policies;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicyMultipleResponses>());
HttpPipeline pipeline(policies);
// The first request is successful, the token gets cached in the credential
{
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
}
}
// The second request returns unauthorized, the token should be invalidated
{
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
}
// We expect the next call to return a new token
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2");
}
}
}
TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry)
{
using namespace std::chrono_literals;