diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 8658986cb..d89a5b2f6 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- When one of the credentials within `DefaultAzureCredential` is successful, it gets re-used during all subsequent attempts to get the token. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp index 52f1707af..48ee1eaf0 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp @@ -13,6 +13,10 @@ #include +#if defined(TESTING_BUILD) +class DefaultAzureCredential_CachingCredential_Test; +#endif + namespace Azure { namespace Identity { namespace _detail { class ChainedTokenCredentialImpl; @@ -37,6 +41,12 @@ namespace Azure { namespace Identity { * */ class DefaultAzureCredential final : public Core::Credentials::TokenCredential { + +#if defined(TESTING_BUILD) + // make tests classes friends to validate caching + friend class ::DefaultAzureCredential_CachingCredential_Test; +#endif + public: /** * @brief Constructs `%DefaultAzureCredential`. diff --git a/sdk/identity/azure-identity/src/chained_token_credential.cpp b/sdk/identity/azure-identity/src/chained_token_credential.cpp index 645cd0f52..8badd7ceb 100644 --- a/sdk/identity/azure-identity/src/chained_token_credential.cpp +++ b/sdk/identity/azure-identity/src/chained_token_credential.cpp @@ -32,8 +32,9 @@ AccessToken ChainedTokenCredential::GetToken( ChainedTokenCredentialImpl::ChainedTokenCredentialImpl( std::string const& credentialName, - ChainedTokenCredential::Sources&& sources) - : m_sources(std::move(sources)) + ChainedTokenCredential::Sources&& sources, + bool reuseSuccessfulSource) + : m_sources(std::move(sources)), m_reuseSuccessfulSource(reuseSuccessfulSource) { auto const logLevel = m_sources.empty() ? IdentityLog::Level::Warning : IdentityLog::Level::Informational; @@ -68,16 +69,52 @@ AccessToken ChainedTokenCredentialImpl::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - for (auto const& source : m_sources) + std::unique_lock lock(m_sourcesMutex, std::defer_lock); + + if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet) { + lock.lock(); + // Check again in case another thread already set the index, and unlock the mutex. + if (m_successfulSourceIndex != SuccessfulSourceNotSet) + { + lock.unlock(); + } + } + + std::size_t i = 0; + std::size_t end = m_sources.size(); + if (m_successfulSourceIndex != SuccessfulSourceNotSet) + { + i = m_successfulSourceIndex; + end = m_successfulSourceIndex + 1; + } + + for (; i < end; ++i) + { + auto& source = m_sources[i]; try { auto token = source->GetToken(tokenRequestContext, context); IdentityLog::Write( IdentityLog::Level::Informational, - credentialName + ": Successfully got token from " + source->GetCredentialName() + '.'); + credentialName + ": Successfully got token from " + source->GetCredentialName() + + (m_reuseSuccessfulSource ? ". Reuse this credential for subsequent calls." : ".")); + // Log first before unlocking the mutex, so that the log message is not interleaved with + // other. + if (m_reuseSuccessfulSource && m_successfulSourceIndex == SuccessfulSourceNotSet) + { + IdentityLog::Write( + IdentityLog::Level::Verbose, + credentialName + ": Save this credential at index " + std::to_string(i) + + " for subsequent calls."); + + // We never re-update the selected credential index, after the first successful credential + // is found. + m_successfulSourceIndex = i; + lock.unlock(); + } return token; } catch (AuthenticationException const& e) diff --git a/sdk/identity/azure-identity/src/default_azure_credential.cpp b/sdk/identity/azure-identity/src/default_azure_credential.cpp index 26d1e06af..cda1bd5d8 100644 --- a/sdk/identity/azure-identity/src/default_azure_credential.cpp +++ b/sdk/identity/azure-identity/src/default_azure_credential.cpp @@ -43,9 +43,12 @@ DefaultAzureCredential::DefaultAzureCredential( auto const managedIdentityCred = std::make_shared(options); auto const azCliCred = std::make_shared(options); + // DefaultAzureCredential caches the selected credential, so that it can be reused on subsequent + // calls. m_impl = std::make_unique<_detail::ChainedTokenCredentialImpl>( GetCredentialName(), - ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred}); + ChainedTokenCredential::Sources{envCred, wiCred, managedIdentityCred, azCliCred}, + true); } DefaultAzureCredential::~DefaultAzureCredential() = default; diff --git a/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp b/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp index 182a5c846..5b285da82 100644 --- a/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp +++ b/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp @@ -5,13 +5,28 @@ #include "azure/identity/chained_token_credential.hpp" +#include +#include +#include + +#if defined(TESTING_BUILD) +class DefaultAzureCredential_CachingCredential_Test; +#endif + namespace Azure { namespace Identity { namespace _detail { class ChainedTokenCredentialImpl final { + +#if defined(TESTING_BUILD) + // make tests classes friends to validate caching + friend class ::DefaultAzureCredential_CachingCredential_Test; +#endif + public: ChainedTokenCredentialImpl( std::string const& credentialName, - ChainedTokenCredential::Sources&& sources); + ChainedTokenCredential::Sources&& sources, + bool reuseSuccessfulSource = false); Core::Credentials::AccessToken GetToken( std::string const& credentialName, @@ -20,6 +35,13 @@ namespace Azure { namespace Identity { namespace _detail { private: ChainedTokenCredential::Sources m_sources; + mutable std::mutex m_sourcesMutex; + // Used as a sentinel value to indicate that the index of the source being used for future calls + // hasn't been found yet. + constexpr static std::size_t SuccessfulSourceNotSet = std::numeric_limits::max(); + // This needs to be atomic so that sentinel comparison is thread safe. + mutable std::atomic m_successfulSourceIndex = {SuccessfulSourceNotSet}; + bool m_reuseSuccessfulSource; }; }}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/test/ut/chained_token_credential_test.cpp b/sdk/identity/azure-identity/test/ut/chained_token_credential_test.cpp index 9a48e3a0f..24f9efe5b 100644 --- a/sdk/identity/azure-identity/test/ut/chained_token_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/chained_token_credential_test.cpp @@ -87,6 +87,17 @@ TEST(ChainedTokenCredential, ErrorThenSuccess) EXPECT_TRUE(c1->WasInvoked); EXPECT_TRUE(c2->WasInvoked); + + // We expect chained token credential will NOT cache the selected credential which was successful + // and retry each one from the start. + c1->WasInvoked = false; + c1->WasInvoked = false; + + token = cred.GetToken({}, {}); + EXPECT_EQ(token.Token, "Token2"); + + EXPECT_TRUE(c1->WasInvoked); + EXPECT_TRUE(c2->WasInvoked); } TEST(ChainedTokenCredential, AllErrors) diff --git a/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp b/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp index 9dede427b..13c55b680 100644 --- a/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp @@ -6,15 +6,44 @@ #include +#include <../src/private/chained_token_credential_impl.hpp> #include using Azure::Identity::DefaultAzureCredential; +using Azure::Core::Context; +using Azure::Core::Credentials::AccessToken; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredential; using Azure::Core::Credentials::TokenCredentialOptions; +using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Diagnostics::Logger; using Azure::Identity::Test::_detail::CredentialTestHelper; namespace { +class TestCredential : public TokenCredential { +private: + std::string m_token; + +public: + TestCredential(std::string token = "") : TokenCredential("TestCredential"), m_token(token) {} + + mutable bool WasInvoked = false; + + AccessToken GetToken(TokenRequestContext const&, Context const&) const override + { + WasInvoked = true; + + if (m_token.empty()) + { + throw AuthenticationException("Test Error"); + } + + AccessToken token; + token.Token = m_token; + return token; + } +}; } // namespace TEST(DefaultAzureCredential, GetCredentialName) @@ -40,6 +69,75 @@ TEST(DefaultAzureCredential, GetCredentialName) EXPECT_EQ(cred.GetCredentialName(), "DefaultAzureCredential"); } +TEST(DefaultAzureCredential, CachingCredential) +{ + auto c1 = std::make_shared(); + auto c2 = std::make_shared("Token2"); + DefaultAzureCredential cred; + + cred.m_impl = std::make_unique( + "Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true); + + EXPECT_FALSE(c1->WasInvoked); + EXPECT_FALSE(c2->WasInvoked); + + auto token = cred.GetToken({}, {}); + EXPECT_EQ(token.Token, "Token2"); + + EXPECT_TRUE(c1->WasInvoked); + EXPECT_TRUE(c2->WasInvoked); + + // We expect default azure credential to cache the selected credential which was successful + // and only try that one, rather than going through the entire list again. + c1->WasInvoked = false; + c1->WasInvoked = false; + + token = cred.GetToken({}, {}); + EXPECT_EQ(token.Token, "Token2"); + + EXPECT_FALSE(c1->WasInvoked); + EXPECT_TRUE(c2->WasInvoked); + + // Only the 2nd credential in the list should get invoked, which is c1, since that's the cached + // index. + c1->WasInvoked = false; + c2->WasInvoked = false; + + cred.m_impl->m_sources = Azure::Identity::ChainedTokenCredential::Sources{c2, c1, c2}; + + // We don't expect c2 to ever be used here. + EXPECT_THROW(static_cast(cred.GetToken({}, {})), AuthenticationException); + + EXPECT_TRUE(c1->WasInvoked); + EXPECT_FALSE(c2->WasInvoked); + + // Caching is per instance of the DefaultAzureCredential and not global. + c1->WasInvoked = false; + c2->WasInvoked = false; + + DefaultAzureCredential cred1; + cred1.m_impl = std::make_unique( + "Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c1, c2}, true); + + DefaultAzureCredential cred2; + cred2.m_impl = std::make_unique( + "Test DAC", Azure::Identity::ChainedTokenCredential::Sources{c2, c1}, true); + + // The first credential in the list, c2, got called and cached on cred2. + token = cred2.GetToken({}, {}); + EXPECT_EQ(token.Token, "Token2"); + + EXPECT_FALSE(c1->WasInvoked); + EXPECT_TRUE(c2->WasInvoked); + + // cred1 is unaffected by cred2 and both c1 and c2 are called, in order. + token = cred1.GetToken({}, {}); + EXPECT_EQ(token.Token, "Token2"); + + EXPECT_TRUE(c1->WasInvoked); + EXPECT_TRUE(c2->WasInvoked); +} + TEST(DefaultAzureCredential, LogMessages) { using LogMsgVec = std::vector>; @@ -150,12 +248,18 @@ TEST(DefaultAzureCredential, LogMessages) EXPECT_EQ( log.size(), - LogMsgVec::size_type(4)); // Request and retry policies will get their messages here as well. + LogMsgVec::size_type(5)); // Request and retry policies will get their messages here as well. EXPECT_EQ(log[3].first, Logger::Level::Informational); EXPECT_EQ( log[3].second, - "Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential."); + "Identity: DefaultAzureCredential: Successfully got token from EnvironmentCredential. Reuse " + "this credential for subsequent calls."); + + EXPECT_EQ(log[4].first, Logger::Level::Verbose); + EXPECT_EQ( + log[4].second, + "Identity: DefaultAzureCredential: Save this credential at index 0 for subsequent calls."); Logger::SetListener(nullptr); }