Set token to be expired if response comes back as unauthorized, within BeareTokenAuthenticationPolicy. (#6151)
* Set token to be expired if response comes back as unauthorized. * Add CL entry. * Update CL. * Use trc MinimumExpiration to invalidate the credential's token cache. * Add test. * Address PR feedback. * Remove comment as it is no longer relevant. * Use initializer list syntax to see if posix compilers are okay with that. * Keep the bool field as non-atomic. * Revert "Keep the bool field as non-atomic." This reverts commit 1b8c7622d5234b010bb0a4eb5db8a436de5a2adf.
This commit is contained in:
parent
c168d736dd
commit
064fcad72f
@ -9,6 +9,7 @@
|
||||
### Bugs Fixed
|
||||
|
||||
- Fixed warning for an unused function in curl.cpp when building the SDK using a version of libcurl older than 7.77.0.
|
||||
- Invalidate the token cache within `BearerTokenAuthenticationPolicy` whenever a token request comes back with a 401 response.
|
||||
|
||||
### Other Changes
|
||||
|
||||
|
||||
@ -576,6 +576,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
|
||||
mutable Credentials::AccessToken m_accessToken;
|
||||
mutable std::shared_timed_mutex m_accessTokenMutex;
|
||||
mutable Credentials::TokenRequestContext m_accessTokenContext;
|
||||
mutable std::atomic<bool> m_invalidateToken = {false};
|
||||
|
||||
public:
|
||||
/**
|
||||
@ -610,6 +611,7 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
|
||||
std::shared_lock<std::shared_timed_mutex> readLock(other.m_accessTokenMutex);
|
||||
m_accessToken = other.m_accessToken;
|
||||
m_accessTokenContext = other.m_accessTokenContext;
|
||||
m_invalidateToken.store(other.m_invalidateToken.load());
|
||||
}
|
||||
|
||||
void operator=(BearerTokenAuthenticationPolicy const&) = delete;
|
||||
|
||||
@ -31,6 +31,7 @@ std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
|
||||
auto result = AuthorizeAndSendRequest(request, nextPolicy, context);
|
||||
{
|
||||
auto const& response = *result;
|
||||
m_invalidateToken = (response.GetStatusCode() == HttpStatusCode::Unauthorized);
|
||||
auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response);
|
||||
if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context))
|
||||
{
|
||||
@ -67,9 +68,10 @@ bool TokenNeedsRefresh(
|
||||
Azure::Core::Credentials::AccessToken const& cachedToken,
|
||||
Azure::Core::Credentials::TokenRequestContext const& cachedTokenRequestContext,
|
||||
Azure::DateTime const& currentTime,
|
||||
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext)
|
||||
Azure::Core::Credentials::TokenRequestContext const& newTokenRequestContext,
|
||||
bool forceRefresh)
|
||||
{
|
||||
return newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
|
||||
return forceRefresh || newTokenRequestContext.TenantId != cachedTokenRequestContext.TenantId
|
||||
|| newTokenRequestContext.Scopes != cachedTokenRequestContext.Scopes
|
||||
|| currentTime > (cachedToken.ExpiresOn - newTokenRequestContext.MinimumExpiration);
|
||||
}
|
||||
@ -91,7 +93,12 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
|
||||
|
||||
{
|
||||
std::shared_lock<std::shared_timed_mutex> readLock(m_accessTokenMutex);
|
||||
if (!TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
|
||||
if (!TokenNeedsRefresh(
|
||||
m_accessToken,
|
||||
m_accessTokenContext,
|
||||
currentTime,
|
||||
tokenRequestContext,
|
||||
m_invalidateToken))
|
||||
{
|
||||
ApplyBearerToken(request, m_accessToken);
|
||||
return;
|
||||
@ -100,10 +107,20 @@ void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
|
||||
|
||||
std::unique_lock<std::shared_timed_mutex> writeLock(m_accessTokenMutex);
|
||||
// Check if token needs refresh for the second time in case another thread has just updated it.
|
||||
if (TokenNeedsRefresh(m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext))
|
||||
if (TokenNeedsRefresh(
|
||||
m_accessToken, m_accessTokenContext, currentTime, tokenRequestContext, m_invalidateToken))
|
||||
{
|
||||
m_accessToken = m_credential->GetToken(tokenRequestContext, context);
|
||||
TokenRequestContext trcCopy = tokenRequestContext;
|
||||
if (m_invalidateToken)
|
||||
{
|
||||
// Need to set this to invalidate the credential's token cache to ensure we fetch a new token
|
||||
// on subsequent GetToken calls.
|
||||
trcCopy.MinimumExpiration = DateTime::duration::max();
|
||||
}
|
||||
|
||||
m_accessToken = m_credential->GetToken(trcCopy, context);
|
||||
m_accessTokenContext = tokenRequestContext;
|
||||
m_invalidateToken = false;
|
||||
}
|
||||
|
||||
ApplyBearerToken(request, m_accessToken);
|
||||
|
||||
@ -52,6 +52,37 @@ public:
|
||||
return std::make_unique<TestTransportPolicy>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
class TestTransportPolicyMultipleResponses final : public HttpPolicy {
|
||||
private:
|
||||
mutable int m_responsesCount = 0;
|
||||
|
||||
public:
|
||||
std::unique_ptr<RawResponse> Send(Request&, NextHttpPolicy, Context const&) const override
|
||||
{
|
||||
if (m_responsesCount == 1)
|
||||
{
|
||||
m_responsesCount++;
|
||||
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Unauthorized, "TestStatus");
|
||||
}
|
||||
if (m_responsesCount == 2)
|
||||
{
|
||||
m_responsesCount++;
|
||||
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
|
||||
}
|
||||
if (m_responsesCount > 2)
|
||||
{
|
||||
EXPECT_TRUE(false);
|
||||
}
|
||||
m_responsesCount++;
|
||||
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
|
||||
}
|
||||
|
||||
std::unique_ptr<HttpPolicy> Clone() const override
|
||||
{
|
||||
return std::make_unique<TestTransportPolicyMultipleResponses>(*this);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(BearerTokenAuthenticationPolicy, InitialGet)
|
||||
@ -169,6 +200,66 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BearerTokenAuthenticationPolicy, TokenInvalidatedAfterUnauth)
|
||||
{
|
||||
using namespace std::chrono_literals;
|
||||
auto accessToken = std::make_shared<AccessToken>();
|
||||
|
||||
std::vector<std::unique_ptr<HttpPolicy>> policies;
|
||||
|
||||
TokenRequestContext tokenRequestContext;
|
||||
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
|
||||
|
||||
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
|
||||
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
|
||||
|
||||
policies.emplace_back(std::make_unique<TestTransportPolicyMultipleResponses>());
|
||||
|
||||
HttpPipeline pipeline(policies);
|
||||
|
||||
// The first request is successful, the token gets cached in the credential
|
||||
{
|
||||
Request request(HttpMethod::Get, Url("https://www.azure.com"));
|
||||
|
||||
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h};
|
||||
|
||||
pipeline.Send(request, Context());
|
||||
|
||||
{
|
||||
auto const headers = request.GetHeaders();
|
||||
auto const authHeader = headers.find("authorization");
|
||||
EXPECT_NE(authHeader, headers.end());
|
||||
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
|
||||
}
|
||||
}
|
||||
|
||||
// The second request returns unauthorized, the token should be invalidated
|
||||
{
|
||||
Request request(HttpMethod::Get, Url("https://www.azure.com"));
|
||||
|
||||
*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};
|
||||
|
||||
pipeline.Send(request, Context());
|
||||
|
||||
{
|
||||
auto const headers = request.GetHeaders();
|
||||
auto const authHeader = headers.find("authorization");
|
||||
EXPECT_NE(authHeader, headers.end());
|
||||
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
|
||||
}
|
||||
|
||||
// We expect the next call to return a new token
|
||||
pipeline.Send(request, Context());
|
||||
|
||||
{
|
||||
auto const headers = request.GetHeaders();
|
||||
auto const authHeader = headers.find("authorization");
|
||||
EXPECT_NE(authHeader, headers.end());
|
||||
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry)
|
||||
{
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user