Re-use the selected credential that works for each subsequent request in DefaultAzureCredential by caching the chosen credential per instance. (#5142)

* Re-use the selected credential that works for each subsequent request in
DefaultAzureCredential by caching the chosen credential per instance.

* Update test to include per-instance caching validation.

* Addresss PR feedback and fix clang error on atomic assignment.

* Fix typo in CL and drop ifdef testing_build to investigate clang build
issue.

* Add double-colon in front of friend class.
This commit is contained in:
Ahson Khan 2023-11-09 19:33:26 -08:00 committed by GitHub
parent a6956c7639
commit 7632d67584
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 8 deletions

View File

@ -4,6 +4,8 @@
### Features Added
- When one of the credentials within `DefaultAzureCredential` is successful, it gets re-used during all subsequent attempts to get the token.
### Breaking Changes
### Bugs Fixed

View File

@ -13,6 +13,10 @@
#include <memory>
#if defined(TESTING_BUILD)
class DefaultAzureCredential_CachingCredential_Test;
#endif
namespace Azure { namespace Identity {
namespace _detail {
class ChainedTokenCredentialImpl;
@ -37,6 +41,12 @@ namespace Azure { namespace Identity {
*
*/
class DefaultAzureCredential final : public Core::Credentials::TokenCredential {
#if defined(TESTING_BUILD)
// make tests classes friends to validate caching
friend class ::DefaultAzureCredential_CachingCredential_Test;
#endif
public:
/**
* @brief Constructs `%DefaultAzureCredential`.

View File

@ -32,8 +32,9 @@ AccessToken ChainedTokenCredential::GetToken(
ChainedTokenCredentialImpl::ChainedTokenCredentialImpl(
std::string const& credentialName,
ChainedTokenCredential::Sources&& sources)
: m_sources(std::move(sources))
ChainedTokenCredential::Sources&& sources,
bool reuseSuccessfulSource)
: m_sources(std::move(sources)), m_reuseSuccessfulSource(reuseSuccessfulSource)
{
auto const logLevel
= m_sources.empty() ? IdentityLog::Level::Warning : IdentityLog::Level::Informational;
@ -68,16 +69,52 @@ AccessToken ChainedTokenCredentialImpl::GetToken(
TokenRequestContext const& tokenRequestContext,
Context const& context) const
{
for (auto const& source : m_sources)
std::unique_lock<std::mutex> lock(m_sourcesMutex, std::defer_lock);
if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet)
{
lock.lock();
// Check again in case another thread already set the index, and unlock the mutex.
if (m_successfulSourceIndex != SuccessfulSourceNotSet)
{
lock.unlock();
}
}
std::size_t i = 0;
std::size_t end = m_sources.size();
if (m_successfulSourceIndex != SuccessfulSourceNotSet)
{
i = m_successfulSourceIndex;
end = m_successfulSourceIndex + 1;
}
for (; i < end; ++i)
{
auto& source = m_sources[i];
try
{
auto token = source->GetToken(tokenRequestContext, context);
IdentityLog::Write(
IdentityLog::Level::Informational,
credentialName + ": Successfully got token from " + source->GetCredentialName() + '.');
credentialName + ": Successfully got token from " + source->GetCredentialName()
+ (m_reuseSuccessfulSource ? ". Reuse this credential for subsequent calls." : "."));
// Log first before unlocking the mutex, so that the log message is not interleaved with
// other.
if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet)
{
IdentityLog::Write(
IdentityLog::Level::Verbose,
credentialName + ": Save this credential at index " + std::to_string(i)
+ " for subsequent calls.");
// We never re-update the selected credential index, after the first successful credential
// is found.
m_successfulSourceIndex = i;
lock.unlock();
}
return token;
}
catch (AuthenticationException const& e)

View File

@ -43,9 +43,12 @@ DefaultAzureCredential::DefaultAzureCredential(
auto const managedIdentityCred = std::make_shared<ManagedIdentityCredential>(options);
auto const azCliCred = std::make_shared<AzureCliCredential>(options);
// DefaultAzureCredential caches the selected credential, so that it can be reused on subsequent
// calls.
m_impl = std::make_unique<_detail::ChainedTokenCredentialImpl>(
GetCredentialName(),
ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred});
ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred},
true);
}
DefaultAzureCredential::~DefaultAzureCredential() = default;

View File

@ -5,13 +5,28 @@
#include "azure/identity/chained_token_credential.hpp"
#include <atomic>
#include <limits>
#include <mutex>
#if defined(TESTING_BUILD)
class DefaultAzureCredential_CachingCredential_Test;
#endif
namespace Azure { namespace Identity { namespace _detail {
class ChainedTokenCredentialImpl final {
#if defined(TESTING_BUILD)
// make tests classes friends to validate caching
friend class ::DefaultAzureCredential_CachingCredential_Test;
#endif
public:
ChainedTokenCredentialImpl(
std::string const& credentialName,
ChainedTokenCredential::Sources&& sources);
ChainedTokenCredential::Sources&& sources,
bool reuseSuccessfulSource = false);
Core::Credentials::AccessToken GetToken(
std::string const& credentialName,
@ -20,6 +35,13 @@ namespace Azure { namespace Identity { namespace _detail {
private:
ChainedTokenCredential::Sources m_sources;
mutable std::mutex m_sourcesMutex;
// Used as a sentinel value to indicate that the index of the source being used for future calls
// hasn't been found yet.
constexpr static std::size_t SuccessfulSourceNotSet = std::numeric_limits<std::size_t>::max();
// This needs to be atomic so that sentinel comparison is thread safe.
mutable std::atomic<std::size_t> m_successfulSourceIndex = {SuccessfulSourceNotSet};
bool m_reuseSuccessfulSource;
};
}}} // namespace Azure::Identity::_detail

View File

@ -87,6 +87,17 @@ TEST(ChainedTokenCredential, ErrorThenSuccess)
EXPECT_TRUE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
// We expect chained token credential will NOT cache the selected credential which was successful
// and retry each one from the start.
c1->WasInvoked = false;
c1->WasInvoked = false;
token = cred.GetToken({}, {});
EXPECT_EQ(token.Token, "Token2");
EXPECT_TRUE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
}
TEST(ChainedTokenCredential, AllErrors)

View File

@ -6,15 +6,44 @@
#include <azure/core/diagnostics/logger.hpp>
#include <../src/private/chained_token_credential_impl.hpp>
#include <gtest/gtest.h>
using Azure::Identity::DefaultAzureCredential;
using Azure::Core::Context;
using Azure::Core::Credentials::AccessToken;
using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenCredential;
using Azure::Core::Credentials::TokenCredentialOptions;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Diagnostics::Logger;
using Azure::Identity::Test::_detail::CredentialTestHelper;
namespace {
class TestCredential : public TokenCredential {
private:
std::string m_token;
public:
TestCredential(std::string token = "") : TokenCredential("TestCredential"), m_token(token) {}
mutable bool WasInvoked = false;
AccessToken GetToken(TokenRequestContext const&, Context const&) const override
{
WasInvoked = true;
if (m_token.empty())
{
throw AuthenticationException("Test Error");
}
AccessToken token;
token.Token = m_token;
return token;
}
};
} // namespace
TEST(DefaultAzureCredential, GetCredentialName)
@ -40,6 +69,75 @@ TEST(DefaultAzureCredential, GetCredentialName)
EXPECT_EQ(cred.GetCredentialName(), "DefaultAzureCredential");
}
TEST(DefaultAzureCredential, CachingCredential)
{
auto c1 = std::make_shared<TestCredential>();
auto c2 = std::make_shared<TestCredential>("Token2");
DefaultAzureCredential cred;
cred.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true);
EXPECT_FALSE(c1->WasInvoked);
EXPECT_FALSE(c2->WasInvoked);
auto token = cred.GetToken({}, {});
EXPECT_EQ(token.Token, "Token2");
EXPECT_TRUE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
// We expect default azure credential to cache the selected credential which was successful
// and only try that one, rather than going through the entire list again.
c1->WasInvoked = false;
c1->WasInvoked = false;
token = cred.GetToken({}, {});
EXPECT_EQ(token.Token, "Token2");
EXPECT_FALSE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
// Only the 2nd credential in the list should get invoked, which is c1, since that's the cached
// index.
c1->WasInvoked = false;
c2->WasInvoked = false;
cred.m_impl->m_sources = Azure::Identity::ChainedTokenCredential::Sources{c2, c1, c2};
// We don't expect c2 to ever be used here.
EXPECT_THROW(static_cast<void>(cred.GetToken({}, {})), AuthenticationException);
EXPECT_TRUE(c1->WasInvoked);
EXPECT_FALSE(c2->WasInvoked);
// Caching is per instance of the DefaultAzureCredential and not global.
c1->WasInvoked = false;
c2->WasInvoked = false;
DefaultAzureCredential cred1;
cred1.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true);
DefaultAzureCredential cred2;
cred2.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c2, c1}, true);
// The first credential in the list, c2, got called and cached on cred2.
token = cred2.GetToken({}, {});
EXPECT_EQ(token.Token, "Token2");
EXPECT_FALSE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
// cred1 is unaffected by cred2 and both c1 and c2 are called, in order.
token = cred1.GetToken({}, {});
EXPECT_EQ(token.Token, "Token2");
EXPECT_TRUE(c1->WasInvoked);
EXPECT_TRUE(c2->WasInvoked);
}
TEST(DefaultAzureCredential, LogMessages)
{
using LogMsgVec = std::vector<std::pair<Logger::Level, std::string>>;
@ -150,12 +248,18 @@ TEST(DefaultAzureCredential, LogMessages)
EXPECT_EQ(
log.size(),
LogMsgVec::size_type(4)); // Request and retry policies will get their messages here as well.
LogMsgVec::size_type(5)); // Request and retry policies will get their messages here as well.
EXPECT_EQ(log[3].first, Logger::Level::Informational);
EXPECT_EQ(
log[3].second,
"Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential.");
"Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential. Reuse "
"this credential for subsequent calls.");
EXPECT_EQ(log[4].first, Logger::Level::Verbose);
EXPECT_EQ(
log[4].second,
"Identity: DefaultAzureCredential: Save this credential at index 0 for subsequent calls.");
Logger::SetListener(nullptr);
}