Re-use the selected credential that works for each subsequent request in DefaultAzureCredential by caching the chosen credential per instance. (#5142)
* Re-use the selected credential that works for each subsequent request in DefaultAzureCredential by caching the chosen credential per instance. * Update test to include per-instance caching validation. * Addresss PR feedback and fix clang error on atomic assignment. * Fix typo in CL and drop ifdef testing_build to investigate clang build issue. * Add double-colon in front of friend class.
This commit is contained in:
parent
a6956c7639
commit
7632d67584
@ -4,6 +4,8 @@
|
||||
|
||||
### Features Added
|
||||
|
||||
- When one of the credentials within `DefaultAzureCredential` is successful, it gets re-used during all subsequent attempts to get the token.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
### Bugs Fixed
|
||||
|
||||
@ -13,6 +13,10 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if defined(TESTING_BUILD)
|
||||
class DefaultAzureCredential_CachingCredential_Test;
|
||||
#endif
|
||||
|
||||
namespace Azure { namespace Identity {
|
||||
namespace _detail {
|
||||
class ChainedTokenCredentialImpl;
|
||||
@ -37,6 +41,12 @@ namespace Azure { namespace Identity {
|
||||
*
|
||||
*/
|
||||
class DefaultAzureCredential final : public Core::Credentials::TokenCredential {
|
||||
|
||||
#if defined(TESTING_BUILD)
|
||||
// make tests classes friends to validate caching
|
||||
friend class ::DefaultAzureCredential_CachingCredential_Test;
|
||||
#endif
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs `%DefaultAzureCredential`.
|
||||
|
||||
@ -32,8 +32,9 @@ AccessToken ChainedTokenCredential::GetToken(
|
||||
|
||||
ChainedTokenCredentialImpl::ChainedTokenCredentialImpl(
|
||||
std::string const& credentialName,
|
||||
ChainedTokenCredential::Sources&& sources)
|
||||
: m_sources(std::move(sources))
|
||||
ChainedTokenCredential::Sources&& sources,
|
||||
bool reuseSuccessfulSource)
|
||||
: m_sources(std::move(sources)), m_reuseSuccessfulSource(reuseSuccessfulSource)
|
||||
{
|
||||
auto const logLevel
|
||||
= m_sources.empty() ? IdentityLog::Level::Warning : IdentityLog::Level::Informational;
|
||||
@ -68,16 +69,52 @@ AccessToken ChainedTokenCredentialImpl::GetToken(
|
||||
TokenRequestContext const& tokenRequestContext,
|
||||
Context const& context) const
|
||||
{
|
||||
for (auto const& source : m_sources)
|
||||
std::unique_lock<std::mutex> lock(m_sourcesMutex, std::defer_lock);
|
||||
|
||||
if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet)
|
||||
{
|
||||
lock.lock();
|
||||
// Check again in case another thread already set the index, and unlock the mutex.
|
||||
if (m_successfulSourceIndex != SuccessfulSourceNotSet)
|
||||
{
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t i = 0;
|
||||
std::size_t end = m_sources.size();
|
||||
if (m_successfulSourceIndex != SuccessfulSourceNotSet)
|
||||
{
|
||||
i = m_successfulSourceIndex;
|
||||
end = m_successfulSourceIndex + 1;
|
||||
}
|
||||
|
||||
for (; i < end; ++i)
|
||||
{
|
||||
auto& source = m_sources[i];
|
||||
try
|
||||
{
|
||||
auto token = source->GetToken(tokenRequestContext, context);
|
||||
|
||||
IdentityLog::Write(
|
||||
IdentityLog::Level::Informational,
|
||||
credentialName + ": Successfully got token from " + source->GetCredentialName() + '.');
|
||||
credentialName + ": Successfully got token from " + source->GetCredentialName()
|
||||
+ (m_reuseSuccessfulSource ? ". Reuse this credential for subsequent calls." : "."));
|
||||
|
||||
// Log first before unlocking the mutex, so that the log message is not interleaved with
|
||||
// other.
|
||||
if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet)
|
||||
{
|
||||
IdentityLog::Write(
|
||||
IdentityLog::Level::Verbose,
|
||||
credentialName + ": Save this credential at index " + std::to_string(i)
|
||||
+ " for subsequent calls.");
|
||||
|
||||
// We never re-update the selected credential index, after the first successful credential
|
||||
// is found.
|
||||
m_successfulSourceIndex = i;
|
||||
lock.unlock();
|
||||
}
|
||||
return token;
|
||||
}
|
||||
catch (AuthenticationException const& e)
|
||||
|
||||
@ -43,9 +43,12 @@ DefaultAzureCredential::DefaultAzureCredential(
|
||||
auto const managedIdentityCred = std::make_shared<ManagedIdentityCredential>(options);
|
||||
auto const azCliCred = std::make_shared<AzureCliCredential>(options);
|
||||
|
||||
// DefaultAzureCredential caches the selected credential, so that it can be reused on subsequent
|
||||
// calls.
|
||||
m_impl = std::make_unique<_detail::ChainedTokenCredentialImpl>(
|
||||
GetCredentialName(),
|
||||
ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred});
|
||||
ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred},
|
||||
true);
|
||||
}
|
||||
|
||||
DefaultAzureCredential::~DefaultAzureCredential() = default;
|
||||
|
||||
@ -5,13 +5,28 @@
|
||||
|
||||
#include "azure/identity/chained_token_credential.hpp"
|
||||
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
|
||||
#if defined(TESTING_BUILD)
|
||||
class DefaultAzureCredential_CachingCredential_Test;
|
||||
#endif
|
||||
|
||||
namespace Azure { namespace Identity { namespace _detail {
|
||||
|
||||
class ChainedTokenCredentialImpl final {
|
||||
|
||||
#if defined(TESTING_BUILD)
|
||||
// make tests classes friends to validate caching
|
||||
friend class ::DefaultAzureCredential_CachingCredential_Test;
|
||||
#endif
|
||||
|
||||
public:
|
||||
ChainedTokenCredentialImpl(
|
||||
std::string const& credentialName,
|
||||
ChainedTokenCredential::Sources&& sources);
|
||||
ChainedTokenCredential::Sources&& sources,
|
||||
bool reuseSuccessfulSource = false);
|
||||
|
||||
Core::Credentials::AccessToken GetToken(
|
||||
std::string const& credentialName,
|
||||
@ -20,6 +35,13 @@ namespace Azure { namespace Identity { namespace _detail {
|
||||
|
||||
private:
|
||||
ChainedTokenCredential::Sources m_sources;
|
||||
mutable std::mutex m_sourcesMutex;
|
||||
// Used as a sentinel value to indicate that the index of the source being used for future calls
|
||||
// hasn't been found yet.
|
||||
constexpr static std::size_t SuccessfulSourceNotSet = std::numeric_limits<std::size_t>::max();
|
||||
// This needs to be atomic so that sentinel comparison is thread safe.
|
||||
mutable std::atomic<std::size_t> m_successfulSourceIndex = {SuccessfulSourceNotSet};
|
||||
bool m_reuseSuccessfulSource;
|
||||
};
|
||||
|
||||
}}} // namespace Azure::Identity::_detail
|
||||
|
||||
@ -87,6 +87,17 @@ TEST(ChainedTokenCredential, ErrorThenSuccess)
|
||||
|
||||
EXPECT_TRUE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
|
||||
// We expect chained token credential will NOT cache the selected credential which was successful
|
||||
// and retry each one from the start.
|
||||
c1->WasInvoked = false;
|
||||
c1->WasInvoked = false;
|
||||
|
||||
token = cred.GetToken({}, {});
|
||||
EXPECT_EQ(token.Token, "Token2");
|
||||
|
||||
EXPECT_TRUE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
}
|
||||
|
||||
TEST(ChainedTokenCredential, AllErrors)
|
||||
|
||||
@ -6,15 +6,44 @@
|
||||
|
||||
#include <azure/core/diagnostics/logger.hpp>
|
||||
|
||||
#include <../src/private/chained_token_credential_impl.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using Azure::Identity::DefaultAzureCredential;
|
||||
|
||||
using Azure::Core::Context;
|
||||
using Azure::Core::Credentials::AccessToken;
|
||||
using Azure::Core::Credentials::AuthenticationException;
|
||||
using Azure::Core::Credentials::TokenCredential;
|
||||
using Azure::Core::Credentials::TokenCredentialOptions;
|
||||
using Azure::Core::Credentials::TokenRequestContext;
|
||||
using Azure::Core::Diagnostics::Logger;
|
||||
using Azure::Identity::Test::_detail::CredentialTestHelper;
|
||||
|
||||
namespace {
|
||||
class TestCredential : public TokenCredential {
|
||||
private:
|
||||
std::string m_token;
|
||||
|
||||
public:
|
||||
TestCredential(std::string token = "") : TokenCredential("TestCredential"), m_token(token) {}
|
||||
|
||||
mutable bool WasInvoked = false;
|
||||
|
||||
AccessToken GetToken(TokenRequestContext const&, Context const&) const override
|
||||
{
|
||||
WasInvoked = true;
|
||||
|
||||
if (m_token.empty())
|
||||
{
|
||||
throw AuthenticationException("Test Error");
|
||||
}
|
||||
|
||||
AccessToken token;
|
||||
token.Token = m_token;
|
||||
return token;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(DefaultAzureCredential, GetCredentialName)
|
||||
@ -40,6 +69,75 @@ TEST(DefaultAzureCredential, GetCredentialName)
|
||||
EXPECT_EQ(cred.GetCredentialName(), "DefaultAzureCredential");
|
||||
}
|
||||
|
||||
TEST(DefaultAzureCredential, CachingCredential)
|
||||
{
|
||||
auto c1 = std::make_shared<TestCredential>();
|
||||
auto c2 = std::make_shared<TestCredential>("Token2");
|
||||
DefaultAzureCredential cred;
|
||||
|
||||
cred.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
|
||||
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true);
|
||||
|
||||
EXPECT_FALSE(c1->WasInvoked);
|
||||
EXPECT_FALSE(c2->WasInvoked);
|
||||
|
||||
auto token = cred.GetToken({}, {});
|
||||
EXPECT_EQ(token.Token, "Token2");
|
||||
|
||||
EXPECT_TRUE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
|
||||
// We expect default azure credential to cache the selected credential which was successful
|
||||
// and only try that one, rather than going through the entire list again.
|
||||
c1->WasInvoked = false;
|
||||
c1->WasInvoked = false;
|
||||
|
||||
token = cred.GetToken({}, {});
|
||||
EXPECT_EQ(token.Token, "Token2");
|
||||
|
||||
EXPECT_FALSE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
|
||||
// Only the 2nd credential in the list should get invoked, which is c1, since that's the cached
|
||||
// index.
|
||||
c1->WasInvoked = false;
|
||||
c2->WasInvoked = false;
|
||||
|
||||
cred.m_impl->m_sources = Azure::Identity::ChainedTokenCredential::Sources{c2, c1, c2};
|
||||
|
||||
// We don't expect c2 to ever be used here.
|
||||
EXPECT_THROW(static_cast<void>(cred.GetToken({}, {})), AuthenticationException);
|
||||
|
||||
EXPECT_TRUE(c1->WasInvoked);
|
||||
EXPECT_FALSE(c2->WasInvoked);
|
||||
|
||||
// Caching is per instance of the DefaultAzureCredential and not global.
|
||||
c1->WasInvoked = false;
|
||||
c2->WasInvoked = false;
|
||||
|
||||
DefaultAzureCredential cred1;
|
||||
cred1.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
|
||||
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true);
|
||||
|
||||
DefaultAzureCredential cred2;
|
||||
cred2.m_impl = std::make_unique<Azure::Identity::_detail::ChainedTokenCredentialImpl>(
|
||||
"Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c2, c1}, true);
|
||||
|
||||
// The first credential in the list, c2, got called and cached on cred2.
|
||||
token = cred2.GetToken({}, {});
|
||||
EXPECT_EQ(token.Token, "Token2");
|
||||
|
||||
EXPECT_FALSE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
|
||||
// cred1 is unaffected by cred2 and both c1 and c2 are called, in order.
|
||||
token = cred1.GetToken({}, {});
|
||||
EXPECT_EQ(token.Token, "Token2");
|
||||
|
||||
EXPECT_TRUE(c1->WasInvoked);
|
||||
EXPECT_TRUE(c2->WasInvoked);
|
||||
}
|
||||
|
||||
TEST(DefaultAzureCredential, LogMessages)
|
||||
{
|
||||
using LogMsgVec = std::vector<std::pair<Logger::Level, std::string>>;
|
||||
@ -150,12 +248,18 @@ TEST(DefaultAzureCredential, LogMessages)
|
||||
|
||||
EXPECT_EQ(
|
||||
log.size(),
|
||||
LogMsgVec::size_type(4)); // Request and retry policies will get their messages here as well.
|
||||
LogMsgVec::size_type(5)); // Request and retry policies will get their messages here as well.
|
||||
|
||||
EXPECT_EQ(log[3].first, Logger::Level::Informational);
|
||||
EXPECT_EQ(
|
||||
log[3].second,
|
||||
"Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential.");
|
||||
"Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential. Reuse "
|
||||
"this credential for subsequent calls.");
|
||||
|
||||
EXPECT_EQ(log[4].first, Logger::Level::Verbose);
|
||||
EXPECT_EQ(
|
||||
log[4].second,
|
||||
"Identity: DefaultAzureCredential: Save this credential at index 0 for subsequent calls.");
|
||||
|
||||
Logger::SetListener(nullptr);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user