From 68e9dc9ecb6980e04a5f779c806e97f8becae974 Mon Sep 17 00:00:00 2001 From: Anton Kolesnyk <41349689+antkmsft@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:55:16 -0800 Subject: [PATCH] Per-credential-instance token cache (#4160) * Per-credential-instance token cache Co-authored-by: Anton Kolesnyk --- sdk/identity/azure-identity/CHANGELOG.md | 3 + sdk/identity/azure-identity/CMakeLists.txt | 3 +- .../azure/identity/azure_cli_credential.hpp | 3 + .../client_certificate_credential.hpp | 3 +- .../identity/client_secret_credential.hpp | 2 + .../inc/azure/identity/detail/token_cache.hpp | 87 +++++++ .../src/azure_cli_credential.cpp | 39 ++-- .../src/client_certificate_credential.cpp | 154 +++++++------ .../src/client_secret_credential.cpp | 38 ++- .../src/managed_identity_source.cpp | 182 +++++++-------- .../src/private/managed_identity_source.hpp | 3 + .../src/private/token_cache_internals.hpp | 92 -------- .../azure-identity/src/token_cache.cpp | 100 +++----- .../test/ut/client_secret_credential_test.cpp | 4 +- .../test/ut/credential_test_helper.cpp | 14 +- .../test/ut/credential_test_helper.hpp | 8 +- .../test/ut/environment_credential_test.cpp | 14 +- .../ut/managed_identity_credential_test.cpp | 30 +-- .../test/ut/token_cache_test.cpp | 217 ++++++++---------- .../test/ut/token_credential_impl_test.cpp | 21 +- .../test/ut/token_credential_test.cpp | 14 +- 21 files changed, 486 insertions(+), 545 deletions(-) create mode 100644 sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp delete mode 100644 sdk/identity/azure-identity/src/private/token_cache_internals.hpp diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 1eb114046..847ef2a8f 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -12,6 +12,9 @@ ### Other Changes +- Changed token cache mode to per-credential-instance. In order to get benefits from token caching, share the same credential between multiple client instances. +- Added token cache support to all credentials. + ## 1.4.0-beta.2 (2022-11-08) ### Features Added diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index 6232727af..97c9aa2af 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -46,6 +46,7 @@ endif() set( AZURE_IDENTITY_HEADER + inc/azure/identity/detail/token_cache.hpp inc/azure/identity/azure_cli_credential.hpp inc/azure/identity/chained_token_credential.hpp inc/azure/identity/client_certificate_credential.hpp @@ -61,8 +62,6 @@ set( AZURE_IDENTITY_SOURCE src/private/managed_identity_source.hpp src/private/package_version.hpp - src/private/token_cache.hpp - src/private/token_cache_internals.hpp src/private/token_credential_impl.hpp src/azure_cli_credential.cpp src/chained_token_credential.cpp diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp index 019793c9c..946c2293a 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -8,6 +8,8 @@ #pragma once +#include "azure/identity/detail/token_cache.hpp" + #include #include @@ -46,6 +48,7 @@ namespace Azure { namespace Identity { #endif : public Core::Credentials::TokenCredential { protected: + _detail::TokenCache m_tokenCache; std::string m_tenantId; DateTime::duration m_cliProcessTimeout; diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp index 620eedbea..e1d712a28 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp @@ -8,7 +8,7 @@ #pragma once -#include "azure/identity/dll_import_export.hpp" +#include "azure/identity/detail/token_cache.hpp" #include #include @@ -37,6 +37,7 @@ namespace Azure { namespace Identity { */ class ClientCertificateCredential final : public Core::Credentials::TokenCredential { private: + _detail::TokenCache m_tokenCache; std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; Core::Url m_requestUrl; std::string m_requestBody; diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp index ed72b29cf..5a3a2cae8 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp @@ -8,6 +8,7 @@ #pragma once +#include "azure/identity/detail/token_cache.hpp" #include "azure/identity/dll_import_export.hpp" #include @@ -47,6 +48,7 @@ namespace Azure { namespace Identity { */ class ClientSecretCredential final : public Core::Credentials::TokenCredential { private: + _detail::TokenCache m_tokenCache; std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; Core::Url m_requestUrl; std::string m_requestBody; diff --git a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp new file mode 100644 index 000000000..523698ded --- /dev/null +++ b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Token cache. + * + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace Azure { namespace Identity { namespace _detail { + /** + * @brief Access token cache. + * + */ + class TokenCache +#if !defined(TESTING_BUILD) + final +#endif + { +#if !defined(TESTING_BUILD) + private: +#else + protected: +#endif + // A test hook that gets invoked before cache write lock gets acquired. + virtual void OnBeforeCacheWriteLock() const {}; + + // A test hook that gets invoked before item write lock gets acquired. + virtual void OnBeforeItemWriteLock() const {}; + + struct CacheValue + { + Core::Credentials::AccessToken AccessToken; + std::shared_timed_mutex ElementMutex; + }; + + mutable std::map> m_cache; + mutable std::shared_timed_mutex m_cacheMutex; + + private: + TokenCache(TokenCache const&) = delete; + TokenCache& operator=(TokenCache const&) = delete; + + // Checks cache element if cached value should be reused. Caller should be holding ElementMutex. + static bool IsFresh( + std::shared_ptr const& item, + DateTime::duration minimumExpiration, + std::chrono::system_clock::time_point now); + + // Gets item from cache, or creates it, puts into cache, and returns. + std::shared_ptr GetOrCreateValue( + std::string const& key, + DateTime::duration minimumExpiration) const; + + public: + TokenCache() = default; + ~TokenCache() = default; + + /** + * @brief Attempts to get token from cache, and if not found, gets the token using the function + * provided, caches it, and returns its value. + * + * @param scopeString Authentication scopes (or resource) as string. + * @param minimumExpiration Minimum token lifetime for the cached value to be returned. + * @param getNewToken Function to get the new token for the given \p scopeString, in case when + * cache does not have it, or if its remaining lifetime is less than \p minimumExpiration. + * + * @return Authentication token. + * + */ + Core::Credentials::AccessToken GetToken( + std::string const& scopeString, + DateTime::duration minimumExpiration, + std::function const& getNewToken) const; + }; +}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/azure_cli_credential.cpp b/sdk/identity/azure-identity/src/azure_cli_credential.cpp index a813d22ab..41dd3280b 100644 --- a/sdk/identity/azure-identity/src/azure_cli_credential.cpp +++ b/sdk/identity/azure-identity/src/azure_cli_credential.cpp @@ -44,6 +44,7 @@ using Azure::Core::Credentials::AuthenticationException; using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Identity::AzureCliCredentialOptions; +using Azure::Identity::_detail::TokenCache; using Azure::Identity::_detail::TokenCredentialImpl; namespace { @@ -121,28 +122,34 @@ AccessToken AzureCliCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - try - { - auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false); - - auto const azCliResult - = RunShellCommand(GetAzCommand(scopes, m_tenantId), m_cliProcessTimeout, context); + auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false); + // TokenCache::GetToken() can only use the lambda argument when they are being executed. They are + // not supposed to keep a reference to lambda argument to call it later. Therefore, any capture + // made here will outlive the possible time frame when the lambda might get called. + return m_tokenCache.GetToken(scopes, tokenRequestContext.MinimumExpiration, [&]() { try { - return TokenCredentialImpl::ParseToken(azCliResult, "accessToken", "expiresIn", "expiresOn"); + auto const azCliResult + = RunShellCommand(GetAzCommand(scopes, m_tenantId), m_cliProcessTimeout, context); + + try + { + return TokenCredentialImpl::ParseToken( + azCliResult, "accessToken", "expiresIn", "expiresOn"); + } + catch (std::exception const&) + { + // Throw the az command output (error message) + // limited to 250 characters (250 has no special meaning). + throw std::runtime_error(azCliResult.substr(0, 250)); + } } - catch (std::exception const&) + catch (std::exception const& e) { - // Throw the az command output (error message) - // limited to 250 characters (250 has no special meaning). - throw std::runtime_error(azCliResult.substr(0, 250)); + throw AuthenticationException(std::string("AzureCliCredential::GetToken(): ") + e.what()); } - } - catch (std::exception const& e) - { - throw AuthenticationException(std::string("AzureCliCredential::GetToken(): ") + e.what()); - } + }); } namespace { diff --git a/sdk/identity/azure-identity/src/client_certificate_credential.cpp b/sdk/identity/azure-identity/src/client_certificate_credential.cpp index 9ecbbd2d3..9c4a60540 100644 --- a/sdk/identity/azure-identity/src/client_certificate_credential.cpp +++ b/sdk/identity/azure-identity/src/client_certificate_credential.cpp @@ -196,100 +196,114 @@ Azure::Core::Credentials::AccessToken ClientCertificateCredential::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - return m_tokenCredentialImpl->GetToken(context, [&]() { - using _detail::TokenCredentialImpl; - using Azure::Core::Http::HttpMethod; - - std::ostringstream body; - body << m_requestBody; + using _detail::TokenCredentialImpl; + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - body << "&scope=" << TokenCredentialImpl::FormatScopes(scopes, false); - } + scopesStr = TokenCredentialImpl::FormatScopes(scopes, false); } + } - std::string assertion = m_tokenHeaderEncoded; - { - using Azure::Core::_internal::Base64Url; - // Form the assertion to sign. + // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCredentialImpl->GetToken(context, [&]() { + using Azure::Core::Http::HttpMethod; + + std::ostringstream body; + body << m_requestBody; { - std::string payloadStr; - // Add GUID, current time, and expiration time to the payload + if (!scopesStr.empty()) { - using Azure::Core::Uuid; - using Azure::Core::_internal::PosixTimeConverter; + body << "&scope=" << scopesStr; + } + } - std::ostringstream payloadStream; + std::string assertion = m_tokenHeaderEncoded; + { + using Azure::Core::_internal::Base64Url; + // Form the assertion to sign. + { + std::string payloadStr; + // Add GUID, current time, and expiration time to the payload + { + using Azure::Core::Uuid; + using Azure::Core::_internal::PosixTimeConverter; - const Azure::DateTime now = std::chrono::system_clock::now(); - const Azure::DateTime exp = now + std::chrono::minutes(10); + std::ostringstream payloadStream; - payloadStream << m_tokenPayloadStaticPart << Uuid::CreateUuid().ToString() - << "\",\"nbf\":" << PosixTimeConverter::DateTimeToPosixTime(now) - << ",\"exp\":" << PosixTimeConverter::DateTimeToPosixTime(exp) << "}"; + const Azure::DateTime now = std::chrono::system_clock::now(); + const Azure::DateTime exp = now + std::chrono::minutes(10); - payloadStr = payloadStream.str(); + payloadStream << m_tokenPayloadStaticPart << Uuid::CreateUuid().ToString() + << "\",\"nbf\":" << PosixTimeConverter::DateTimeToPosixTime(now) + << ",\"exp\":" << PosixTimeConverter::DateTimeToPosixTime(exp) << "}"; + + payloadStr = payloadStream.str(); + } + + // Concatenate JWT token header + "." + encoded payload + const auto payloadVec + = std::vector(payloadStr.begin(), payloadStr.end()); + + assertion += std::string(".") + Base64Url::Base64UrlEncode(ToUInt8Vector(payloadVec)); } - // Concatenate JWT token header + "." + encoded payload - const auto payloadVec - = std::vector(payloadStr.begin(), payloadStr.end()); - - assertion += std::string(".") + Base64Url::Base64UrlEncode(ToUInt8Vector(payloadVec)); - } - - // Get assertion signature. - std::string signature; - if (auto mdCtx = EVP_MD_CTX_new()) - { - try + // Get assertion signature. + std::string signature; + if (auto mdCtx = EVP_MD_CTX_new()) { - EVP_PKEY_CTX* signCtx = nullptr; - if ((EVP_DigestSignInit( - mdCtx, &signCtx, EVP_sha256(), nullptr, static_cast(m_pkey)) - == 1) - && (EVP_PKEY_CTX_set_rsa_padding(signCtx, RSA_PKCS1_PADDING) == 1)) + try { - size_t sigLen = 0; - if (EVP_DigestSign(mdCtx, nullptr, &sigLen, nullptr, 0) == 1) + EVP_PKEY_CTX* signCtx = nullptr; + if ((EVP_DigestSignInit( + mdCtx, &signCtx, EVP_sha256(), nullptr, static_cast(m_pkey)) + == 1) + && (EVP_PKEY_CTX_set_rsa_padding(signCtx, RSA_PKCS1_PADDING) == 1)) { - const auto bufToSign = reinterpret_cast(assertion.data()); - const auto bufToSignLen = static_cast(assertion.size()); - - std::vector sigVec(sigLen); - if (EVP_DigestSign(mdCtx, sigVec.data(), &sigLen, bufToSign, bufToSignLen) == 1) + size_t sigLen = 0; + if (EVP_DigestSign(mdCtx, nullptr, &sigLen, nullptr, 0) == 1) { - signature = Base64Url::Base64UrlEncode(ToUInt8Vector(sigVec)); + const auto bufToSign = reinterpret_cast(assertion.data()); + const auto bufToSignLen = static_cast(assertion.size()); + + std::vector sigVec(sigLen); + if (EVP_DigestSign(mdCtx, sigVec.data(), &sigLen, bufToSign, bufToSignLen) == 1) + { + signature = Base64Url::Base64UrlEncode(ToUInt8Vector(sigVec)); + } } } - } - if (signature.empty()) + if (signature.empty()) + { + throw Azure::Core::Credentials::AuthenticationException( + "Failed to sign token request."); + } + + EVP_MD_CTX_free(mdCtx); + } + catch (...) { - throw Azure::Core::Credentials::AuthenticationException( - "Failed to sign token request."); + EVP_MD_CTX_free(mdCtx); + throw; } + } - EVP_MD_CTX_free(mdCtx); - } - catch (...) - { - EVP_MD_CTX_free(mdCtx); - throw; - } + // Add signature to the end of assertion + assertion += std::string(".") + signature; } - // Add signature to the end of assertion - assertion += std::string(".") + signature; - } + body << "&client_assertion=" << Azure::Core::Url::Encode(assertion); - body << "&client_assertion=" << Azure::Core::Url::Encode(assertion); + auto request = std::make_unique( + HttpMethod::Post, m_requestUrl, body.str()); - auto request = std::make_unique( - HttpMethod::Post, m_requestUrl, body.str()); - - return request; + return request; + }); }); } diff --git a/sdk/identity/azure-identity/src/client_secret_credential.cpp b/sdk/identity/azure-identity/src/client_secret_credential.cpp index 0cb63c787..8f7073f2e 100644 --- a/sdk/identity/azure-identity/src/client_secret_credential.cpp +++ b/sdk/identity/azure-identity/src/client_secret_credential.cpp @@ -3,7 +3,6 @@ #include "azure/identity/client_secret_credential.hpp" -#include "private/token_cache.hpp" #include "private/token_credential_impl.hpp" #include @@ -68,7 +67,6 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - using _detail::TokenCache; using _detail::TokenCredentialImpl; std::string scopesStr; @@ -84,30 +82,24 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return TokenCache::GetToken( - m_tenantId, - m_clientId, - m_authorityHost, - scopesStr, - tokenRequestContext.MinimumExpiration, - [&]() { - return m_tokenCredentialImpl->GetToken(context, [&]() { - using Azure::Core::Http::HttpMethod; + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCredentialImpl->GetToken(context, [&]() { + using Azure::Core::Http::HttpMethod; - std::ostringstream body; - body << m_requestBody; + std::ostringstream body; + body << m_requestBody; - if (!scopesStr.empty()) - { - body << "&scope=" << scopesStr; - } + if (!scopesStr.empty()) + { + body << "&scope=" << scopesStr; + } - auto request = std::make_unique( - HttpMethod::Post, m_requestUrl, body.str()); + auto request = std::make_unique( + HttpMethod::Post, m_requestUrl, body.str()); - request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); + request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); - return request; - }); - }); + return request; + }); + }); } diff --git a/sdk/identity/azure-identity/src/managed_identity_source.cpp b/sdk/identity/azure-identity/src/managed_identity_source.cpp index 3dcb4d38a..dc4af9d04 100644 --- a/sdk/identity/azure-identity/src/managed_identity_source.cpp +++ b/sdk/identity/azure-identity/src/managed_identity_source.cpp @@ -3,8 +3,6 @@ #include "private/managed_identity_source.hpp" -#include "private/token_cache.hpp" - #include #include @@ -96,24 +94,18 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return TokenCache::GetToken( - std::string(), - GetClientId(), - GetAuthorityHost(), - scopesStr, - tokenRequestContext.MinimumExpiration, - [&]() { - return TokenCredentialImpl::GetToken(context, [&]() { - auto request = std::make_unique(m_request); + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + auto request = std::make_unique(m_request); - if (!scopesStr.empty()) - { - request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); - } + if (!scopesStr.empty()) + { + request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); + } - return request; - }); - }); + return request; + }); + }); } std::unique_ptr AppServiceV2017ManagedIdentitySource::Create( @@ -175,34 +167,28 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return TokenCache::GetToken( - std::string(), - GetClientId(), - GetAuthorityHost(), - scopesStr, - tokenRequestContext.MinimumExpiration, - [&]() { - return TokenCredentialImpl::GetToken(context, [&]() { - using Azure::Core::Url; - using Azure::Core::Http::HttpMethod; + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + using Azure::Core::Url; + using Azure::Core::Http::HttpMethod; - std::string resource; + std::string resource; - if (!scopesStr.empty()) - { - resource = "resource=" + scopesStr; - if (!m_body.empty()) - { - resource += "&"; - } - } + if (!scopesStr.empty()) + { + resource = "resource=" + scopesStr; + if (!m_body.empty()) + { + resource += "&"; + } + } - auto request = std::make_unique(HttpMethod::Post, m_url, resource + m_body); - request->HttpRequest.SetHeader("Metadata", "true"); + auto request = std::make_unique(HttpMethod::Post, m_url, resource + m_body); + request->HttpRequest.SetHeader("Metadata", "true"); - return request; - }); - }); + return request; + }); + }); } std::unique_ptr AzureArcManagedIdentitySource::Create( @@ -276,57 +262,51 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return TokenCache::GetToken( - std::string(), - GetClientId(), - GetAuthorityHost(), - scopesStr, - tokenRequestContext.MinimumExpiration, - [&]() { - return TokenCredentialImpl::GetToken( - context, - createRequest, - [&](auto const statusCode, auto const& response) -> std::unique_ptr { - using Core::Credentials::AuthenticationException; - using Core::Http::HttpStatusCode; + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return TokenCredentialImpl::GetToken( + context, + createRequest, + [&](auto const statusCode, auto const& response) -> std::unique_ptr { + using Core::Credentials::AuthenticationException; + using Core::Http::HttpStatusCode; - if (statusCode != HttpStatusCode::Unauthorized) - { - return nullptr; - } + if (statusCode != HttpStatusCode::Unauthorized) + { + return nullptr; + } - auto const& headers = response.GetHeaders(); - auto authHeader = headers.find("WWW-Authenticate"); - if (authHeader == headers.end()) - { - throw AuthenticationException( - "Did not receive expected WWW-Authenticate header " - "in the response from Azure Arc Managed Identity Endpoint."); - } + auto const& headers = response.GetHeaders(); + auto authHeader = headers.find("WWW-Authenticate"); + if (authHeader == headers.end()) + { + throw AuthenticationException( + "Did not receive expected WWW-Authenticate header " + "in the response from Azure Arc Managed Identity Endpoint."); + } - constexpr auto ChallengeValueSeparator = '='; - auto const& challenge = authHeader->second; - auto eq = challenge.find(ChallengeValueSeparator); - if (eq == std::string::npos - || challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos) - { - throw AuthenticationException( - "The WWW-Authenticate header in the response from Azure Arc " - "Managed Identity Endpoint did not match the expected format."); - } + constexpr auto ChallengeValueSeparator = '='; + auto const& challenge = authHeader->second; + auto eq = challenge.find(ChallengeValueSeparator); + if (eq == std::string::npos + || challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos) + { + throw AuthenticationException( + "The WWW-Authenticate header in the response from Azure Arc " + "Managed Identity Endpoint did not match the expected format."); + } - auto request = createRequest(); - std::ifstream secretFile(challenge.substr(eq + 1)); - request->HttpRequest.SetHeader( - "Authorization", - "Basic " - + std::string( - std::istreambuf_iterator(secretFile), - std::istreambuf_iterator())); + auto request = createRequest(); + std::ifstream secretFile(challenge.substr(eq + 1)); + request->HttpRequest.SetHeader( + "Authorization", + "Basic " + + std::string( + std::istreambuf_iterator(secretFile), + std::istreambuf_iterator())); - return request; - }); - }); + return request; + }); + }); } std::unique_ptr ImdsManagedIdentitySource::Create( @@ -376,22 +356,16 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return TokenCache::GetToken( - std::string(), - GetClientId(), - GetAuthorityHost(), - scopesStr, - tokenRequestContext.MinimumExpiration, - [&]() { - return TokenCredentialImpl::GetToken(context, [&]() { - auto request = std::make_unique(m_request); + return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + auto request = std::make_unique(m_request); - if (!scopesStr.empty()) - { - request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); - } + if (!scopesStr.empty()) + { + request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); + } - return request; - }); - }); + return request; + }); + }); } diff --git a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp index 04ec7b6a7..cf80baf9e 100644 --- a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp +++ b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp @@ -3,6 +3,8 @@ #pragma once +#include "azure/identity/detail/token_cache.hpp" + #include #include #include @@ -25,6 +27,7 @@ namespace Azure { namespace Identity { namespace _detail { Core::Context const& context) const = 0; protected: + _detail::TokenCache m_tokenCache; static Core::Url ParseEndpointUrl(std::string const& url, char const* envVarName); explicit ManagedIdentitySource( diff --git a/sdk/identity/azure-identity/src/private/token_cache_internals.hpp b/sdk/identity/azure-identity/src/private/token_cache_internals.hpp deleted file mode 100644 index b1e612ce9..000000000 --- a/sdk/identity/azure-identity/src/private/token_cache_internals.hpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// SPDX-License-Identifier: MIT - -/** - * @file - * @brief Token cache internals and test hooks. - */ - -#pragma once - -#include "token_cache.hpp" - -#if defined(TESTING_BUILD) -#include "azure/identity/dll_import_export.hpp" -#endif - -#include - -#include -#include -#include -#include -#include -#include - -namespace Azure { namespace Identity { namespace _detail { - /** - * @brief Implements internal aspects of token cache and provides test hooks. - * - */ - class TokenCache::Internals final { - Internals() = delete; - ~Internals() = delete; - - public: - /** - * @brief Represents a unique set of characteristics that are used to distinguish between cache - * entries. - * - */ - struct CacheKey final - { - std::string TenantId; ///< Tenant ID. - std::string ClientId; ///< Client ID. - std::string AuthorityHost; ///< Authority Host. - std::string Scopes; ///< Authentication Scopes as a single string. - - bool operator<(TokenCache::Internals::CacheKey const& other) const - { - return std::tie(TenantId, ClientId, AuthorityHost, Scopes) - < std::tie(other.TenantId, other.ClientId, other.AuthorityHost, other.Scopes); - } - }; - - /** - * @brief Represents immediate cache value (token) and a synchronization primitive to handle its - * updates. - * - */ - struct CacheValue final - { - std::shared_timed_mutex ElementMutex; - Core::Credentials::AccessToken AccessToken; - }; - - /** - * @brief The cache itself. - * - */ - static std::map> Cache; - - /** - * @brief Mutex to access the cache container. - * - */ - static std::shared_timed_mutex CacheMutex; - -#if defined(TESTING_BUILD) - /** - * A test hook that gets invoked before cache write lock gets acquired. - * - */ - AZ_IDENTITY_DLLEXPORT static std::function OnBeforeCacheWriteLock; - - /** - * A test hook that gets invoked before item write lock gets acquired. - * - */ - AZ_IDENTITY_DLLEXPORT static std::function OnBeforeItemWriteLock; -#endif - }; -}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/token_cache.cpp b/sdk/identity/azure-identity/src/token_cache.cpp index 4f88476e8..18e66cfcf 100644 --- a/sdk/identity/azure-identity/src/token_cache.cpp +++ b/sdk/identity/azure-identity/src/token_cache.cpp @@ -1,10 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-License-Identifier: MIT -#include "private/token_cache.hpp" -#include "private/token_cache_internals.hpp" +#include "azure/identity/detail/token_cache.hpp" -#include #include using Azure::Identity::_detail::TokenCache; @@ -12,57 +10,45 @@ using Azure::Identity::_detail::TokenCache; using Azure::DateTime; using Azure::Core::Credentials::AccessToken; -decltype(TokenCache::Internals::Cache) TokenCache::Internals::Cache; -decltype(TokenCache::Internals::CacheMutex) TokenCache::Internals::CacheMutex; - -#if defined(TESTING_BUILD) -std::function TokenCache::Internals::OnBeforeCacheWriteLock; -std::function TokenCache::Internals::OnBeforeItemWriteLock; -#endif - -namespace { -bool IsFresh( - std::shared_ptr const& item, +bool TokenCache::IsFresh( + std::shared_ptr const& item, DateTime::duration minimumExpiration, std::chrono::system_clock::time_point now) { return item->AccessToken.ExpiresOn > (DateTime(now) + minimumExpiration); } -std::shared_ptr GetOrCreateValue( - TokenCache::Internals::CacheKey const& key, - DateTime::duration minimumExpiration) +std::shared_ptr TokenCache::GetOrCreateValue( + std::string const& key, + DateTime::duration minimumExpiration) const { { - std::shared_lock cacheReadLock(TokenCache::Internals::CacheMutex); + std::shared_lock cacheReadLock(m_cacheMutex); - auto const found = TokenCache::Internals::Cache.find(key); - if (found != TokenCache::Internals::Cache.end()) + auto const found = m_cache.find(key); + if (found != TokenCache::m_cache.end()) { return found->second; } } #if defined(TESTING_BUILD) - if (TokenCache::Internals::OnBeforeCacheWriteLock != nullptr) - { - TokenCache::Internals::OnBeforeCacheWriteLock(); - } + OnBeforeCacheWriteLock(); #endif - std::unique_lock cacheWriteLock(TokenCache::Internals::CacheMutex); + std::unique_lock cacheWriteLock(m_cacheMutex); // Search cache for the second time, in case the item was inserted between releasing the read lock // and acquiring the write lock. - auto const found = TokenCache::Internals::Cache.find(key); - if (found != TokenCache::Internals::Cache.end()) + auto const found = m_cache.find(key); + if (found != m_cache.end()) { return found->second; } // Clean up cache from expired items (once every N insertions). { - auto const cacheSize = TokenCache::Internals::Cache.size(); + auto const cacheSize = m_cache.size(); // N: cacheSize (before insertion) is >= 32 and is a power of two. // 32 as a starting point does not have any special meaning. @@ -74,8 +60,8 @@ std::shared_ptr GetOrCreateValue( { auto now = std::chrono::system_clock::now(); - auto iter = TokenCache::Internals::Cache.begin(); - while (iter != TokenCache::Internals::Cache.end()) + auto iter = m_cache.begin(); + while (iter != m_cache.end()) { // Should we end up erasing the element, iterator to current will become invalid, after // which we can't increment it. So we copy current, and safely advance the loop iterator. @@ -90,7 +76,7 @@ std::shared_ptr GetOrCreateValue( std::unique_lock lock(item->ElementMutex, std::defer_lock); if (lock.try_lock() && !IsFresh(item, minimumExpiration, now)) { - TokenCache::Internals::Cache.erase(curr); + m_cache.erase(curr); } } } @@ -98,20 +84,15 @@ std::shared_ptr GetOrCreateValue( } // Insert the blank value value and return it. - return TokenCache::Internals::Cache[key] = std::make_shared(); + return m_cache[key] = std::make_shared(); } -} // namespace AccessToken TokenCache::GetToken( - std::string const& tenantId, - std::string const& clientId, - std::string const& authorityHost, - std::string const& scopes, + std::string const& scopeString, DateTime::duration minimumExpiration, - std::function const& getNewToken) + std::function const& getNewToken) const { - auto const item - = GetOrCreateValue({tenantId, clientId, authorityHost, scopes}, minimumExpiration); + auto const item = GetOrCreateValue(scopeString, minimumExpiration); { std::shared_lock itemReadLock(item->ElementMutex); @@ -122,33 +103,20 @@ AccessToken TokenCache::GetToken( } } +#if defined(TESTING_BUILD) + OnBeforeItemWriteLock(); +#endif + + std::unique_lock itemWriteLock(item->ElementMutex); + + // Check the expiration for the second time, in case it just got updated, after releasing the + // itemReadLock, and before acquiring itemWriteLock. + if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now())) { -#if defined(TESTING_BUILD) - if (TokenCache::Internals::OnBeforeItemWriteLock != nullptr) - { - TokenCache::Internals::OnBeforeItemWriteLock(); - } -#endif - - std::unique_lock itemWriteLock(item->ElementMutex); - - // Check the expiration for the second time, in case it just got updated, after releasing the - // itemReadLock, and before acquiring itemWriteLock. - if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now())) - { - return item->AccessToken; - } - - auto const newToken = getNewToken(); - item->AccessToken = newToken; - return newToken; + return item->AccessToken; } -} -#if defined(TESTING_BUILD) -void TokenCache::Clear() -{ - std::unique_lock cacheWriteLock(TokenCache::Internals::CacheMutex); - Internals::Cache.clear(); + auto const newToken = getNewToken(); + item->AccessToken = newToken; + return newToken; } -#endif diff --git a/sdk/identity/azure-identity/test/ut/client_secret_credential_test.cpp b/sdk/identity/azure-identity/test/ut/client_secret_credential_test.cpp index 1777553da..7625bc8e6 100644 --- a/sdk/identity/azure-identity/test/ut/client_secret_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/client_secret_credential_test.cpp @@ -172,7 +172,7 @@ TEST(ClientSecretCredential, Authority) "CLIENTSECRET1", options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); auto const actual2 = CredentialTestHelper::SimulateTokenRequest( @@ -184,7 +184,7 @@ TEST(ClientSecretCredential, Authority) return std::make_unique( "adfs", "01234567-89ab-cdef-fedc-ba8976543210", "CLIENTSECRET2", options); }, - {{{"https://outlook.com/.default"}}}, + {{"https://outlook.com/.default"}}, {"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"}); EXPECT_EQ(actual1.Requests.size(), 1U); diff --git a/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp b/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp index 6285b4e78..cb3b509eb 100644 --- a/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp +++ b/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp @@ -3,7 +3,6 @@ #include "credential_test_helper.hpp" -#include "private/token_cache_internals.hpp" #include #include @@ -67,19 +66,18 @@ CredentialTestHelper::GetTokenCallback const CredentialTestHelper::DefaultGetTok CredentialTestHelper::TokenRequestSimulationResult CredentialTestHelper::SimulateTokenRequest( CredentialTestHelper::CreateCredentialCallback const& createCredential, - std::vector const& tokenRequestContexts, + std::vector const& + tokenRequestContextScopes, std::vector const& responses, GetTokenCallback getToken) { - Azure::Identity::_detail::TokenCache::Clear(); - using Azure::Core::Context; using Azure::Core::Http::HttpStatusCode; using Azure::Core::Http::RawResponse; using Azure::Core::IO::MemoryBodyStream; auto const nResponses = responses.size(); - auto const nRequestTimes = tokenRequestContexts.size(); + auto const nRequestTimes = tokenRequestContextScopes.size(); TokenRequestSimulationResult result; { @@ -139,7 +137,11 @@ CredentialTestHelper::TokenRequestSimulationResult CredentialTestHelper::Simulat { TokenRequestSimulationResult::ResponseInfo response{}; - response.AccessToken = getToken(*credential, tokenRequestContexts.at(i), Context()); + Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = tokenRequestContextScopes.at(i); + tokenRequestContext.MinimumExpiration = std::chrono::hours(1000000); + + response.AccessToken = getToken(*credential, tokenRequestContext, Context()); response.EarliestExpiration = earliestExpiration; response.LatestExpiration = std::chrono::system_clock::now(); diff --git a/sdk/identity/azure-identity/test/ut/credential_test_helper.hpp b/sdk/identity/azure-identity/test/ut/credential_test_helper.hpp index bd35726c6..244ffcdac 100644 --- a/sdk/identity/azure-identity/test/ut/credential_test_helper.hpp +++ b/sdk/identity/azure-identity/test/ut/credential_test_helper.hpp @@ -75,13 +75,15 @@ namespace Azure { namespace Identity { namespace Test { namespace _detail { static TokenRequestSimulationResult SimulateTokenRequest( CreateCredentialCallback const& createCredential, - std::vector const& tokenRequestContexts, + std::vector const& + tokenRequestContextScopes, std::vector const& responses, GetTokenCallback getToken = DefaultGetToken); static TokenRequestSimulationResult SimulateTokenRequest( CreateCredentialCallback const& createCredential, - std::vector const& tokenRequestContexts, + std::vector const& + tokenRequestContextScopes, std::vector const& responseBodies, GetTokenCallback getToken = DefaultGetToken) { @@ -92,7 +94,7 @@ namespace Azure { namespace Identity { namespace Test { namespace _detail { responses.push_back({HttpStatusCode::Ok, responseBody, {}}); } - return SimulateTokenRequest(createCredential, tokenRequestContexts, responses, getToken); + return SimulateTokenRequest(createCredential, tokenRequestContextScopes, responses, getToken); } }; }}}} // namespace Azure::Identity::Test::_detail diff --git a/sdk/identity/azure-identity/test/ut/environment_credential_test.cpp b/sdk/identity/azure-identity/test/ut/environment_credential_test.cpp index f7c07e657..d7ebfacf1 100644 --- a/sdk/identity/azure-identity/test/ut/environment_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/environment_credential_test.cpp @@ -30,7 +30,7 @@ TEST(EnvironmentCredential, RegularClientSecretCredential) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); EXPECT_EQ(actual.Requests.size(), 1U); @@ -85,7 +85,7 @@ TEST(EnvironmentCredential, AzureStackClientSecretCredential) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); EXPECT_EQ(actual.Requests.size(), 1U); @@ -141,7 +141,7 @@ TEST(EnvironmentCredential, Unavailable) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -169,7 +169,7 @@ TEST(EnvironmentCredential, ClientSecretDefaultAuthority) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); EXPECT_EQ(actual.Requests.size(), 1U); @@ -227,7 +227,7 @@ TEST(EnvironmentCredential, ClientSecretNoTenantId) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -258,7 +258,7 @@ TEST(EnvironmentCredential, ClientSecretNoClientId) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -289,7 +289,7 @@ TEST(EnvironmentCredential, ClientSecretNoClientSecret) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; diff --git a/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp b/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp index 27f7d5b80..8ae5e8eb3 100644 --- a/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp @@ -33,7 +33,7 @@ TEST(ManagedIdentityCredential, AppServiceV2019) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -120,7 +120,7 @@ TEST(ManagedIdentityCredential, AppServiceV2019ClientId) return std::make_unique( "fedcba98-7654-3210-0123-456789abcdef", options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -273,7 +273,7 @@ TEST(ManagedIdentityCredential, AppServiceV2017) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -360,7 +360,7 @@ TEST(ManagedIdentityCredential, AppServiceV2017ClientId) return std::make_unique( "fedcba98-7654-3210-0123-456789abcdef", options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -513,7 +513,7 @@ TEST(ManagedIdentityCredential, CloudShell) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -587,7 +587,7 @@ TEST(ManagedIdentityCredential, CloudShellClientId) return std::make_unique( "fedcba98-7654-3210-0123-456789abcdef", options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -718,7 +718,7 @@ TEST(ManagedIdentityCredential, AzureArc) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, {{HttpStatusCode::Unauthorized, "", {{"WWW-Authenticate", "ABC ABC=managed_identity_credential_test1.txt"}}}, @@ -883,7 +883,7 @@ TEST(ManagedIdentityCredential, AzureArcAuthHeaderMissing) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {{HttpStatusCode::Unauthorized, "", {}}, {HttpStatusCode::Ok, "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", {}}}, [](auto& credential, auto& tokenRequestContext, auto& context) { @@ -922,7 +922,7 @@ TEST(ManagedIdentityCredential, AzureArcUnexpectedHttpStatusCode) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {{HttpStatusCode::Forbidden, "", {{"WWW-Authenticate", "ABC ABC=managed_identity_credential_test0.txt"}}}, @@ -956,7 +956,7 @@ TEST(ManagedIdentityCredential, AzureArcAuthHeaderNoEquals) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {{HttpStatusCode::Unauthorized, "", {{"WWW-Authenticate", "ABCSECRET1"}}}, {HttpStatusCode::Ok, "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", {}}}, [](auto& credential, auto& tokenRequestContext, auto& context) { @@ -988,7 +988,7 @@ TEST(ManagedIdentityCredential, AzureArcAuthHeaderTwoEquals) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {{HttpStatusCode::Unauthorized, "", {{"WWW-Authenticate", "ABC=SECRET1=SECRET2"}}}, {HttpStatusCode::Ok, "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", {}}}, [](auto& credential, auto& tokenRequestContext, auto& context) { @@ -1048,7 +1048,7 @@ TEST(ManagedIdentityCredential, Imds) return std::make_unique(options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -1135,7 +1135,7 @@ TEST(ManagedIdentityCredential, ImdsClientId) return std::make_unique( "fedcba98-7654-3210-0123-456789abcdef", options); }, - {{{"https://azure.com/.default"}}, {{"https://outlook.com/.default"}}, {{}}}, + {{"https://azure.com/.default"}, {"https://outlook.com/.default"}, {}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}", @@ -1225,7 +1225,7 @@ TEST(ManagedIdentityCredential, ImdsCreation) return std::make_unique( "fedcba98-7654-3210-0123-456789abcdef", options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); auto const actual2 = CredentialTestHelper::SimulateTokenRequest( @@ -1245,7 +1245,7 @@ TEST(ManagedIdentityCredential, ImdsCreation) return std::make_unique( "01234567-89ab-cdef-fedc-ba9876543210", options); }, - {{{"https://outlook.com/.default"}}}, + {{"https://outlook.com/.default"}}, {"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"}); EXPECT_EQ(actual1.Requests.size(), 1U); diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp index 78c436822..cc72e64bf 100644 --- a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-License-Identifier: MIT -#include "private/token_cache_internals.hpp" +#include "azure/identity/detail/token_cache.hpp" #include @@ -11,74 +11,59 @@ using Azure::DateTime; using Azure::Core::Credentials::AccessToken; using Azure::Identity::_detail::TokenCache; -using namespace std::chrono_literals; +namespace { +class TestableTokenCache final : public TokenCache { +public: + using TokenCache::CacheValue; + using TokenCache::m_cache; + using TokenCache::m_cacheMutex; -TEST(TokenCache, KeyComparison) -{ - using Key = TokenCache::Internals::CacheKey; - Key const key1{"a", "b", "c", "d"}; - EXPECT_FALSE(key1 < key1); + mutable std::function m_onBeforeCacheWriteLock; + mutable std::function m_onBeforeItemWriteLock; + void OnBeforeCacheWriteLock() const override { - Key const key1dup{"a", "b", "c", "d"}; - - EXPECT_FALSE(key1 < key1dup); - EXPECT_FALSE(key1dup < key1); + if (m_onBeforeCacheWriteLock != nullptr) + { + m_onBeforeCacheWriteLock(); + } } - Key const key2{"a", "b", "c", "~"}; - Key const key3{"a", "b", "~", "d"}; - Key const key4{"a", "~", "c", "d"}; - Key const key5{"~", "b", "c", "d"}; + void OnBeforeItemWriteLock() const override + { + if (m_onBeforeItemWriteLock != nullptr) + { + m_onBeforeItemWriteLock(); + } + } +}; +} // namespace - EXPECT_TRUE(key1 < key2); - EXPECT_TRUE(key1 < key3); - EXPECT_TRUE(key1 < key4); - EXPECT_TRUE(key1 < key5); - EXPECT_FALSE(key2 < key1); - EXPECT_FALSE(key3 < key1); - EXPECT_FALSE(key4 < key1); - EXPECT_FALSE(key5 < key1); - - EXPECT_TRUE(key2 < key3); - EXPECT_TRUE(key2 < key4); - EXPECT_TRUE(key2 < key5); - EXPECT_FALSE(key3 < key2); - EXPECT_FALSE(key4 < key2); - EXPECT_FALSE(key5 < key2); - - EXPECT_TRUE(key3 < key4); - EXPECT_TRUE(key3 < key5); - EXPECT_FALSE(key4 < key3); - EXPECT_FALSE(key5 < key3); - - EXPECT_TRUE(key4 < key5); - EXPECT_FALSE(key5 < key4); -} +using namespace std::chrono_literals; TEST(TokenCache, GetReuseRefresh) { - TokenCache::Clear(); + TestableTokenCache tokenCache; - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; auto const Yesterday = Tomorrow - 48h; { - auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token1.ExpiresOn, Tomorrow); EXPECT_EQ(token1.Token, "T1"); - auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token2 = tokenCache.GetToken("A", 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); AccessToken result; result.Token = "T2"; @@ -86,23 +71,23 @@ TEST(TokenCache, GetReuseRefresh) return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); EXPECT_EQ(token1.Token, token2.Token); } { - TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache["A"]->AccessToken.ExpiresOn = Yesterday; - auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token = tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T3"; result.ExpiresOn = Tomorrow + 1min; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); EXPECT_EQ(token.Token, "T3"); @@ -111,15 +96,15 @@ TEST(TokenCache, GetReuseRefresh) TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) { - TokenCache::Clear(); + TestableTokenCache tokenCache; - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - TokenCache::Internals::OnBeforeCacheWriteLock = [=]() { - TokenCache::Internals::OnBeforeCacheWriteLock = nullptr; - static_cast(TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + tokenCache.m_onBeforeCacheWriteLock = [&]() { + tokenCache.m_onBeforeCacheWriteLock = nullptr; + static_cast(tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -127,7 +112,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) })); }; - auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token = tokenCache.GetToken("A", 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " "acquiring cache write lock"); AccessToken result; @@ -136,7 +121,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token.ExpiresOn, Tomorrow); EXPECT_EQ(token.Token, "T1"); @@ -144,22 +129,22 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) { - TokenCache::Clear(); - - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; auto const Yesterday = Tomorrow - 48h; { - TokenCache::Internals::OnBeforeItemWriteLock = [=]() { - TokenCache::Internals::OnBeforeItemWriteLock = nullptr; - auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}]; + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache["A"]; item->AccessToken.Token = "T1"; item->AccessToken.ExpiresOn = Tomorrow; }; - auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token = tokenCache.GetToken("A", 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " "acquiring item write lock"); AccessToken result; @@ -168,7 +153,7 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token.ExpiresOn, Tomorrow); EXPECT_EQ(token.Token, "T1"); @@ -176,23 +161,23 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) // Same as above, but the token that was inserted is already expired. { - TokenCache::Clear(); + TestableTokenCache tokenCache; - TokenCache::Internals::OnBeforeItemWriteLock = [=]() { - TokenCache::Internals::OnBeforeItemWriteLock = nullptr; - auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}]; + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache["A"]; item->AccessToken.Token = "T3"; item->AccessToken.ExpiresOn = Yesterday; }; - auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token = tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T4"; result.ExpiresOn = Tomorrow + 3min; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); EXPECT_EQ(token.Token, "T4"); @@ -204,13 +189,13 @@ TEST(TokenCache, ExpiredCleanup) DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; auto const Yesterday = Tomorrow - 48h; - TokenCache::Clear(); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + TestableTokenCache tokenCache; + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); for (auto i = 1; i <= 65; ++i) { auto const n = std::to_string(i); - static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -219,20 +204,20 @@ TEST(TokenCache, ExpiredCleanup) } // Simply: we added 64+1 token, none of them has expired. None are expected to be cleaned up. - EXPECT_EQ(TokenCache::Internals::Cache.size(), 65UL); + EXPECT_EQ(tokenCache.m_cache.size(), 65UL); // Let's expire 3 of them, with numbers from 1 to 3. for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; } // Add tokens up to 128 total. When 129th gets added, clean up should get triggered. for (auto i = 66; i <= 128; ++i) { auto const n = std::to_string(i); - static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -240,17 +225,17 @@ TEST(TokenCache, ExpiredCleanup) })); } - EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL); + EXPECT_EQ(tokenCache.m_cache.size(), 128UL); // Count is at 128. Tokens from 1 to 3 are still in cache even though they are expired. for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } // One more addition to the cache and cleanup for the expired ones will get triggered. - static_cast(TokenCache::GetToken("129", "129", "129", "129", 2min, [=]() { + static_cast(tokenCache.GetToken("129", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -258,20 +243,20 @@ TEST(TokenCache, ExpiredCleanup) })); // We were at 128 before we added 1 more, and now we're at 126. 3 were deleted, 1 was added. - EXPECT_EQ(TokenCache::Internals::Cache.size(), 126UL); + EXPECT_EQ(tokenCache.m_cache.size(), 126UL); // Items from 1 to 3 should no longer be in the cache. for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } // Let's expire items from 21 all the way up to 129. for (auto i = 21; i <= 129; ++i) { auto const n = std::to_string(i); - TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; } // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to @@ -279,7 +264,7 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 2; i <= 3; ++i) { auto const n = std::to_string(i); - static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, 2min, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow; @@ -288,26 +273,26 @@ TEST(TokenCache, ExpiredCleanup) } // Cache is now at 128 again (items from 2 to 129). Adding 1 more will trigger cleanup. - EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL); + EXPECT_EQ(tokenCache.m_cache.size(), 128UL); // Now let's lock some of the items for reading, and some for writing. Cleanup should not block on // token release, but will simply move on, without doing anything to the ones that were locked. // Out of 4 locked, two are expired, so they should get cleared under normla circumstances, but // this time they will remain in the cache. std::shared_lock readLockForUnexpired( - TokenCache::Internals::Cache[{"2", "2", "2", "2"}]->ElementMutex); + tokenCache.m_cache["2"]->ElementMutex); std::shared_lock readLockForExpired( - TokenCache::Internals::Cache[{"127", "127", "127", "127"}]->ElementMutex); + tokenCache.m_cache["127"]->ElementMutex); std::unique_lock writeLockForUnexpired( - TokenCache::Internals::Cache[{"3", "3", "3", "3"}]->ElementMutex); + tokenCache.m_cache["3"]->ElementMutex); std::unique_lock writeLockForExpired( - TokenCache::Internals::Cache[{"128", "128", "128", "128"}]->ElementMutex); + tokenCache.m_cache["128"]->ElementMutex); // Count is at 128. Inserting the 129th element, and it will trigger cleanup. - static_cast(TokenCache::GetToken("1", "1", "1", "1", 2min, [=]() { + static_cast(tokenCache.GetToken("1", 2min, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow; @@ -315,57 +300,53 @@ TEST(TokenCache, ExpiredCleanup) })); // These should be 20 unexpired items + two that are expired but were locked, so 22 total. - EXPECT_EQ(TokenCache::Internals::Cache.size(), 22UL); + EXPECT_EQ(tokenCache.m_cache.size(), 22UL); for (auto i = 1; i <= 20; ++i) { auto const n = std::to_string(i); - EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } - EXPECT_NE( - TokenCache::Internals::Cache.find({"127", "127", "127", "127"}), - TokenCache::Internals::Cache.end()); + EXPECT_NE(tokenCache.m_cache.find("127"), tokenCache.m_cache.end()); - EXPECT_NE( - TokenCache::Internals::Cache.find({"128", "128", "128", "128"}), - TokenCache::Internals::Cache.end()); + EXPECT_NE(tokenCache.m_cache.find("128"), tokenCache.m_cache.end()); for (auto i = 21; i <= 126; ++i) { auto const n = std::to_string(i); - EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } } TEST(TokenCache, MinimumExpiration) { - TokenCache::Clear(); + TestableTokenCache tokenCache; - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token1.ExpiresOn, Tomorrow); EXPECT_EQ(token1.Token, "T1"); - auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 24h, [=]() { + auto const token2 = tokenCache.GetToken("A", 24h, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow + 1h; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); EXPECT_EQ(token2.Token, "T2"); @@ -373,34 +354,33 @@ TEST(TokenCache, MinimumExpiration) TEST(TokenCache, MultithreadedAccess) { - TokenCache::Clear(); + TestableTokenCache tokenCache; - EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token1.ExpiresOn, Tomorrow); EXPECT_EQ(token1.Token, "T1"); { - std::shared_lock itemReadLock( - TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex); + std::shared_lock itemReadLock(tokenCache.m_cache["A"]->ElementMutex); { - std::shared_lock cacheReadLock(TokenCache::Internals::CacheMutex); + std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); // Parallel threads read both the container and the item we're accessing, and we can access it // in parallel as well. - auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + auto const token2 = tokenCache.GetToken("A", 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); AccessToken result; result.Token = "T2"; @@ -408,7 +388,7 @@ TEST(TokenCache, MultithreadedAccess) return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); EXPECT_EQ(token2.Token, token1.Token); @@ -416,33 +396,32 @@ TEST(TokenCache, MultithreadedAccess) // The cache is unlocked, but one item is being read in a parallel thread, which does not // prevent new items (with different key) from being appended to cache. - auto const token3 = TokenCache::GetToken("E", "F", "G", "H", 2min, [=]() { + auto const token3 = tokenCache.GetToken("B", 2min, [=]() { AccessToken result; result.Token = "T3"; result.ExpiresOn = Tomorrow + 2h; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 2UL); + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); EXPECT_EQ(token3.Token, "T3"); } { - std::unique_lock itemWriteLock( - TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex); + std::unique_lock itemWriteLock(tokenCache.m_cache["A"]->ElementMutex); // The cache is unlocked, but one item is being written in a parallel thread, which does not // prevent new items (with different key) from being appended to cache. - auto const token3 = TokenCache::GetToken("I", "J", "K", "L", 2min, [=]() { + auto const token3 = tokenCache.GetToken("C", 2min, [=]() { AccessToken result; result.Token = "T4"; result.ExpiresOn = Tomorrow + 3h; return result; }); - EXPECT_EQ(TokenCache::Internals::Cache.size(), 3UL); + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); EXPECT_EQ(token3.Token, "T4"); diff --git a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp index 11bed803f..6189fb98e 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp @@ -3,8 +3,6 @@ #include "private/token_credential_impl.hpp" -#include "private/token_cache.hpp" - #include "credential_test_helper.hpp" #include @@ -20,7 +18,6 @@ using Azure::Core::Credentials::TokenCredential; using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Http::HttpMethod; -using Azure::Identity::_detail::TokenCache; using Azure::Identity::_detail::TokenCredentialImpl; using Azure::Identity::Test::_detail::CredentialTestHelper; @@ -53,8 +50,6 @@ public: AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const& context) const override { - TokenCache::Clear(); - return m_tokenCredentialImpl->GetToken(context, [&]() { m_throwingFunction(); @@ -80,9 +75,9 @@ TEST(TokenCredentialImpl, Normal) return std::make_unique( HttpMethod::Delete, Url("https://outlook.com/"), options); }, - {{{"https://azure.com/.default", "https://microsoft.com/.default"}}, - {{"https://azure.com/.default", "https://microsoft.com/.default"}}, - {{"https://azure.com/.default", "https://microsoft.com/.default"}}}, + {{"https://azure.com/.default", "https://microsoft.com/.default"}, + {"https://azure.com/.default", "https://microsoft.com/.default"}, + {"https://azure.com/.default", "https://microsoft.com/.default"}}, std::vector{ "{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}", "{\"access_token\":\"ACCESSTOKEN2\", \"expires_in\":7200}", @@ -157,7 +152,7 @@ TEST(TokenCredentialImpl, StdException) return std::make_unique( []() { throw std::exception(); }, options); }, - {{{"https://azure.com/.default", "https://microsoft.com/.default"}}}, + {{"https://azure.com/.default", "https://microsoft.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -176,7 +171,7 @@ TEST(TokenCredentialImpl, ThrowInt) return std::make_unique([]() { throw 0; }, options); }, - {{{"https://azure.com/.default", "https://microsoft.com/.default"}}}, + {{"https://azure.com/.default", "https://microsoft.com/.default"}}, {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -310,7 +305,7 @@ TEST(TokenCredentialImpl, NoExpiration) return std::make_unique( HttpMethod::Delete, Url("https://outlook.com/"), options); }, - {{{"https://azure.com/.default", "https://microsoft.com/.default"}}}, + {{"https://azure.com/.default", "https://microsoft.com/.default"}}, {"{\"access_token\":\"ACCESSTOKEN\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -330,7 +325,7 @@ TEST(TokenCredentialImpl, NoToken) return std::make_unique( HttpMethod::Delete, Url("https://outlook.com/"), options); }, - {{{"https://azure.com/.default", "https://microsoft.com/.default"}}}, + {{"https://azure.com/.default", "https://microsoft.com/.default"}}, {"{\"expires_in\":3600}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; @@ -369,7 +364,7 @@ TEST(TokenCredentialImpl, NullResponse) return std::make_unique( HttpMethod::Delete, Url("https://microsoft.com/"), options); }, - {{{"https://azure.com/.default"}}}, + {{"https://azure.com/.default"}}, {{"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN\"}"}}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; diff --git a/sdk/identity/azure-identity/test/ut/token_credential_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_test.cpp index 5dca4c58c..bfe57d7c0 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_test.cpp @@ -7,8 +7,6 @@ #include #include -#include "private/token_cache.hpp" - #include #include @@ -60,10 +58,12 @@ TEST_F(TokenCredentialTest, ClientSecret) std::string const testName(GetTestName()); auto const clientSecretCredential = GetClientSecretCredential(testName); - _detail::TokenCache::Clear(); + Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://vault.azure.net/.default"}; + tokenRequestContext.MinimumExpiration = std::chrono::hours(1000000); auto const token = clientSecretCredential->GetToken( - {{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext); + tokenRequestContext, Azure::Core::Context::ApplicationContext); EXPECT_FALSE(token.Token.empty()); EXPECT_GE(token.ExpiresOn, std::chrono::system_clock::now()); @@ -74,10 +74,12 @@ TEST_F(TokenCredentialTest, EnvironmentCredential) std::string const testName(GetTestName()); auto const clientSecretCredential = GetEnvironmentCredential(testName); - _detail::TokenCache::Clear(); + Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://vault.azure.net/.default"}; + tokenRequestContext.MinimumExpiration = std::chrono::hours(1000000); auto const token = clientSecretCredential->GetToken( - {{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext); + tokenRequestContext, Azure::Core::Context::ApplicationContext); EXPECT_FALSE(token.Token.empty()); EXPECT_GE(token.ExpiresOn, std::chrono::system_clock::now());