In-memory Identity token cache (#4024)

This commit is contained in:
Anton Kolesnyk 2022-10-25 13:23:22 -07:00 committed by GitHub
parent 34485a7ab7
commit 4de2423934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1034 additions and 119 deletions

View File

@ -10,6 +10,8 @@
### Other Changes
- Added support for Identity token caching, and for configuring token refresh offset in `BearerTokenAuthenticationPolicy`.
## 1.8.0-beta.1 (2022-10-06)
### Features Added

View File

@ -48,6 +48,12 @@ namespace Azure { namespace Core { namespace Credentials {
*
*/
std::vector<std::string> Scopes;
/**
* @brief Minimum token expiration suggestion.
*
*/
DateTime::duration MinimumExpiration = std::chrono::minutes(2);
};
/**
@ -61,6 +67,8 @@ namespace Azure { namespace Core { namespace Credentials {
* @param tokenRequestContext A context to get the token in.
* @param context A context to control the request lifetime.
*
* @return Authentication token.
*
* @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred.
*/
virtual AccessToken GetToken(

View File

@ -18,8 +18,8 @@ std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
{
std::lock_guard<std::mutex> lock(m_accessTokenMutex);
// Refresh the token in 2 or less minutes before the actual expiration.
if (std::chrono::system_clock::now() > (m_accessToken.ExpiresOn - std::chrono::minutes(2)))
if (std::chrono::system_clock::now()
> (m_accessToken.ExpiresOn - m_tokenRequestContext.MinimumExpiration))
{
m_accessToken = m_credential->GetToken(m_tokenRequestContext, context);
}

View File

@ -4,6 +4,8 @@
### Features Added
- Added token caching.
### Breaking Changes
### Bugs Fixed

View File

@ -60,6 +60,8 @@ set(
AZURE_IDENTITY_SOURCE
src/private/managed_identity_source.hpp
src/private/package_version.hpp
src/private/token_cache.hpp
src/private/token_cache_internals.hpp
src/private/token_credential_impl.hpp
src/chained_token_credential.cpp
src/client_certificate_credential.cpp
@ -67,6 +69,7 @@ set(
src/environment_credential.cpp
src/managed_identity_credential.cpp
src/managed_identity_source.cpp
src/token_cache.cpp
src/token_credential_impl.cpp
)
@ -106,6 +109,9 @@ az_rtti_setup(
)
if(BUILD_TESTING)
# define a symbol that enables some test hooks in code
add_compile_definitions(TESTING_BUILD)
# tests
if (NOT AZ_ALL_LIBRARIES OR FETCH_SOURCE_DEPS)
include(AddGoogleTest)

View File

@ -51,13 +51,18 @@ namespace Azure { namespace Identity {
std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl;
Core::Url m_requestUrl;
std::string m_requestBody;
std::string m_tenantId;
std::string m_clientId;
std::string m_authorityHost;
bool m_isAdfs;
ClientSecretCredential(
std::string const& tenantId,
std::string const& clientId,
std::string tenantId,
std::string clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::string authorityHost,
Core::Credentials::TokenCredentialOptions const& options);
public:
@ -70,8 +75,8 @@ namespace Azure { namespace Identity {
* @param options Options for token retrieval.
*/
explicit ClientSecretCredential(
std::string const& tenantId,
std::string const& clientId,
std::string tenantId,
std::string clientId,
std::string const& clientSecret,
ClientSecretCredentialOptions const& options);
@ -86,7 +91,7 @@ namespace Azure { namespace Identity {
explicit ClientSecretCredential(
std::string tenantId,
std::string clientId,
std::string clientSecret,
std::string const& clientSecret,
Core::Credentials::TokenCredentialOptions const& options
= Core::Credentials::TokenCredentialOptions());
@ -102,6 +107,8 @@ namespace Azure { namespace Identity {
* @param tokenRequestContext A context to get the token in.
* @param context A context to control the request lifetime.
*
* @return Authentication token.
*
* @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred.
*/
Core::Credentials::AccessToken GetToken(

View File

@ -59,6 +59,8 @@ namespace Azure { namespace Identity {
* @param tokenRequestContext A context to get the token in.
* @param context A context to control the request lifetime.
*
* @return Authentication token.
*
* @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred.
*/
Core::Credentials::AccessToken GetToken(

View File

@ -3,9 +3,11 @@
#include "azure/identity/client_secret_credential.hpp"
#include "private/token_cache.hpp"
#include "private/token_credential_impl.hpp"
#include <sstream>
#include <utility>
using namespace Azure::Identity;
@ -13,29 +15,33 @@ std::string const Azure::Identity::_detail::g_aadGlobalAuthority
= "https://login.microsoftonline.com/";
ClientSecretCredential::ClientSecretCredential(
std::string const& tenantId,
std::string const& clientId,
std::string tenantId,
std::string clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::string authorityHost,
Azure::Core::Credentials::TokenCredentialOptions const& options)
: m_tokenCredentialImpl(std::make_unique<_detail::TokenCredentialImpl>(options)),
m_isAdfs(tenantId == "adfs")
m_tenantId(std::move(tenantId)), m_clientId(std::move(clientId)),
m_authorityHost(std::move(authorityHost))
{
using Azure::Core::Url;
m_requestUrl = Url(authorityHost);
m_requestUrl.AppendPath(tenantId);
m_isAdfs = (m_tenantId == "adfs");
m_requestUrl = Url(m_authorityHost);
m_requestUrl.AppendPath(m_tenantId);
m_requestUrl.AppendPath(m_isAdfs ? "oauth2/token" : "oauth2/v2.0/token");
std::ostringstream body;
body << "grant_type=client_credentials&client_id=" << Url::Encode(clientId)
body << "grant_type=client_credentials&client_id=" << Url::Encode(m_clientId)
<< "&client_secret=" << Url::Encode(clientSecret);
m_requestBody = body.str();
}
ClientSecretCredential::ClientSecretCredential(
std::string const& tenantId,
std::string const& clientId,
std::string tenantId,
std::string clientId,
std::string const& clientSecret,
ClientSecretCredentialOptions const& options)
: ClientSecretCredential(tenantId, clientId, clientSecret, options.AuthorityHost, options)
@ -45,7 +51,7 @@ ClientSecretCredential::ClientSecretCredential(
ClientSecretCredential::ClientSecretCredential(
std::string tenantId,
std::string clientId,
std::string clientSecret,
std::string const& clientSecret,
Core::Credentials::TokenCredentialOptions const& options)
: ClientSecretCredential(
tenantId,
@ -62,28 +68,49 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken(
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
Azure::Core::Context const& context) const
{
return m_tokenCredentialImpl->GetToken(context, [&]() {
using _detail::TokenCredentialImpl;
using Azure::Core::Http::HttpMethod;
using _detail::TokenCache;
using _detail::TokenCredentialImpl;
std::ostringstream body;
body << m_requestBody;
std::string scopesStr;
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
body << "&scope=" << TokenCredentialImpl::FormatScopes(scopes, m_isAdfs);
}
scopesStr = TokenCredentialImpl::FormatScopes(scopes, m_isAdfs);
}
}
auto request = std::make_unique<TokenCredentialImpl::TokenRequest>(
HttpMethod::Post, m_requestUrl, body.str());
// TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return TokenCache::GetToken(
m_tenantId,
m_clientId,
m_authorityHost,
scopesStr,
tokenRequestContext.MinimumExpiration,
[&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
using Azure::Core::Http::HttpMethod;
if (m_isAdfs)
{
request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost());
}
std::ostringstream body;
body << m_requestBody;
return request;
});
if (!scopesStr.empty())
{
body << "&scope=" << scopesStr;
}
auto request = std::make_unique<TokenCredentialImpl::TokenRequest>(
HttpMethod::Post, m_requestUrl, body.str());
if (m_isAdfs)
{
request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost());
}
return request;
});
});
}

View File

@ -3,6 +3,8 @@
#include "private/managed_identity_source.hpp"
#include "private/token_cache.hpp"
#include <azure/core/internal/environment.hpp>
#include <fstream>
@ -59,7 +61,7 @@ AppServiceManagedIdentitySource::AppServiceManagedIdentitySource(
std::string const& apiVersion,
std::string const& secretHeaderName,
std::string const& clientIdHeaderName)
: ManagedIdentitySource(options),
: ManagedIdentitySource(clientId, endpointUrl.GetHost(), options),
m_request(Azure::Core::Http::HttpMethod::Get, std::move(endpointUrl))
{
{
@ -81,18 +83,37 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken(
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
Azure::Core::Context const& context) const
{
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
std::string scopesStr;
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
request->HttpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true));
}
scopesStr = TokenCredentialImpl::FormatScopes(scopes, true);
}
}
return request;
});
// TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return TokenCache::GetToken(
std::string(),
GetClientId(),
GetAuthorityHost(),
scopesStr,
tokenRequestContext.MinimumExpiration,
[&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
if (!scopesStr.empty())
{
request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr);
}
return request;
});
});
}
std::unique_ptr<ManagedIdentitySource> AppServiceV2017ManagedIdentitySource::Create(
@ -128,7 +149,7 @@ CloudShellManagedIdentitySource::CloudShellManagedIdentitySource(
std::string const& clientId,
Azure::Core::Credentials::TokenCredentialOptions const& options,
Azure::Core::Url endpointUrl)
: ManagedIdentitySource(options), m_url(std::move(endpointUrl))
: ManagedIdentitySource(clientId, endpointUrl.GetHost(), options), m_url(std::move(endpointUrl))
{
using Azure::Core::Url;
if (!clientId.empty())
@ -141,28 +162,47 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken(
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
Azure::Core::Context const& context) const
{
return TokenCredentialImpl::GetToken(context, [&]() {
using Azure::Core::Url;
using Azure::Core::Http::HttpMethod;
std::string resource;
std::string scopesStr;
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
resource = "resource=" + FormatScopes(scopes, true);
if (!m_body.empty())
{
resource += "&";
}
}
scopesStr = TokenCredentialImpl::FormatScopes(scopes, true);
}
}
auto request = std::make_unique<TokenRequest>(HttpMethod::Post, m_url, resource + m_body);
request->HttpRequest.SetHeader("Metadata", "true");
// TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return TokenCache::GetToken(
std::string(),
GetClientId(),
GetAuthorityHost(),
scopesStr,
tokenRequestContext.MinimumExpiration,
[&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
using Azure::Core::Url;
using Azure::Core::Http::HttpMethod;
return request;
});
std::string resource;
if (!scopesStr.empty())
{
resource = "resource=" + scopesStr;
if (!m_body.empty())
{
resource += "&";
}
}
auto request = std::make_unique<TokenRequest>(HttpMethod::Post, m_url, resource + m_body);
request->HttpRequest.SetHeader("Metadata", "true");
return request;
});
});
}
std::unique_ptr<ManagedIdentitySource> AzureArcManagedIdentitySource::Create(
@ -194,7 +234,8 @@ std::unique_ptr<ManagedIdentitySource> AzureArcManagedIdentitySource::Create(
AzureArcManagedIdentitySource::AzureArcManagedIdentitySource(
Azure::Core::Credentials::TokenCredentialOptions const& options,
Azure::Core::Url endpointUrl)
: ManagedIdentitySource(options), m_url(std::move(endpointUrl))
: ManagedIdentitySource(std::string(), endpointUrl.GetHost(), options),
m_url(std::move(endpointUrl))
{
m_url.AppendQueryParameter("api-version", "2019-11-01");
@ -204,6 +245,15 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken(
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
Azure::Core::Context const& context) const
{
std::string scopesStr;
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
scopesStr = TokenCredentialImpl::FormatScopes(scopes, true);
}
}
auto const createRequest = [&]() {
using Azure::Core::Http::HttpMethod;
using Azure::Core::Http::Request;
@ -212,59 +262,70 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken(
{
auto& httpRequest = request->HttpRequest;
httpRequest.SetHeader("Metadata", "true");
if (!scopesStr.empty())
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
httpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true));
}
httpRequest.GetUrl().AppendQueryParameter("resource", scopesStr);
}
}
return request;
};
return TokenCredentialImpl::GetToken(
context,
createRequest,
[&](auto const statusCode, auto const& response) -> std::unique_ptr<TokenRequest> {
using Core::Credentials::AuthenticationException;
using Core::Http::HttpStatusCode;
// TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return TokenCache::GetToken(
std::string(),
GetClientId(),
GetAuthorityHost(),
scopesStr,
tokenRequestContext.MinimumExpiration,
[&]() {
return TokenCredentialImpl::GetToken(
context,
createRequest,
[&](auto const statusCode, auto const& response) -> std::unique_ptr<TokenRequest> {
using Core::Credentials::AuthenticationException;
using Core::Http::HttpStatusCode;
if (statusCode != HttpStatusCode::Unauthorized)
{
return nullptr;
}
if (statusCode != HttpStatusCode::Unauthorized)
{
return nullptr;
}
auto const& headers = response.GetHeaders();
auto authHeader = headers.find("WWW-Authenticate");
if (authHeader == headers.end())
{
throw AuthenticationException(
"Did not receive expected WWW-Authenticate header "
"in the response from Azure Arc Managed Identity Endpoint.");
}
auto const& headers = response.GetHeaders();
auto authHeader = headers.find("WWW-Authenticate");
if (authHeader == headers.end())
{
throw AuthenticationException(
"Did not receive expected WWW-Authenticate header "
"in the response from Azure Arc Managed Identity Endpoint.");
}
constexpr auto ChallengeValueSeparator = '=';
auto const& challenge = authHeader->second;
auto eq = challenge.find(ChallengeValueSeparator);
if (eq == std::string::npos
|| challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos)
{
throw AuthenticationException("The WWW-Authenticate header in the response from Azure "
"Arc Managed Identity Endpoint "
"did not match the expected format.");
}
constexpr auto ChallengeValueSeparator = '=';
auto const& challenge = authHeader->second;
auto eq = challenge.find(ChallengeValueSeparator);
if (eq == std::string::npos
|| challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos)
{
throw AuthenticationException(
"The WWW-Authenticate header in the response from Azure Arc "
"Managed Identity Endpoint did not match the expected format.");
}
auto request = createRequest();
std::ifstream secretFile(challenge.substr(eq + 1));
request->HttpRequest.SetHeader(
"Authorization",
"Basic "
+ std::string(
std::istreambuf_iterator<char>(secretFile), std::istreambuf_iterator<char>()));
auto request = createRequest();
std::ifstream secretFile(challenge.substr(eq + 1));
request->HttpRequest.SetHeader(
"Authorization",
"Basic "
+ std::string(
std::istreambuf_iterator<char>(secretFile),
std::istreambuf_iterator<char>()));
return request;
return request;
});
});
}
@ -278,7 +339,7 @@ std::unique_ptr<ManagedIdentitySource> ImdsManagedIdentitySource::Create(
ImdsManagedIdentitySource::ImdsManagedIdentitySource(
std::string const& clientId,
Azure::Core::Credentials::TokenCredentialOptions const& options)
: ManagedIdentitySource(options),
: ManagedIdentitySource(clientId, std::string(), options),
m_request(
Azure::Core::Http::HttpMethod::Get,
Azure::Core::Url("http://169.254.169.254/metadata/identity/oauth2/token"))
@ -302,16 +363,35 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
Azure::Core::Context const& context) const
{
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
std::string scopesStr;
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
request->HttpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true));
}
scopesStr = TokenCredentialImpl::FormatScopes(scopes, true);
}
}
return request;
});
// TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return TokenCache::GetToken(
std::string(),
GetClientId(),
GetAuthorityHost(),
scopesStr,
tokenRequestContext.MinimumExpiration,
[&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
if (!scopesStr.empty())
{
request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr);
}
return request;
});
});
}

View File

@ -11,9 +11,14 @@
#include <memory>
#include <string>
#include <utility>
namespace Azure { namespace Identity { namespace _detail {
class ManagedIdentitySource : protected TokenCredentialImpl {
private:
std::string m_clientId;
std::string m_authorityHost;
public:
virtual Core::Credentials::AccessToken GetToken(
Core::Credentials::TokenRequestContext const& tokenRequestContext,
@ -22,10 +27,17 @@ namespace Azure { namespace Identity { namespace _detail {
protected:
static Core::Url ParseEndpointUrl(std::string const& url, char const* envVarName);
explicit ManagedIdentitySource(Core::Credentials::TokenCredentialOptions const& options)
: TokenCredentialImpl(options)
explicit ManagedIdentitySource(
std::string clientId,
std::string authorityHost,
Core::Credentials::TokenCredentialOptions const& options)
: TokenCredentialImpl(options), m_clientId(std::move(clientId)),
m_authorityHost(std::move(authorityHost))
{
}
std::string const& GetClientId() const { return m_clientId; }
std::string const& GetAuthorityHost() const { return m_authorityHost; }
};
class AppServiceManagedIdentitySource : public ManagedIdentitySource {

View File

@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Token cache.
*/
#pragma once
#include <azure/core/credentials/credentials.hpp>
#include <azure/core/datetime.hpp>
#include <functional>
#include <string>
namespace Azure { namespace Identity { namespace _detail {
/**
* @brief Implements an access token cache.
*
*/
class TokenCache final {
TokenCache() = delete;
~TokenCache() = delete;
public:
/**
* @brief Attempts to get token from cache, and if not found, gets the token using the function
* provided, caches it, and returns its value.
*
* @param tenantId Azure Tenant ID.
* @param clientId Azure Client ID.
* @param authorityHost Authentication authority URL.
* @param scopes Authentication scopes.
*
* @return Authentication token.
*
* @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred.
*/
static Core::Credentials::AccessToken GetToken(
std::string const& tenantId,
std::string const& clientId,
std::string const& authorityHost,
std::string const& scopes,
DateTime::duration minimumExpiration,
std::function<Core::Credentials::AccessToken()> const& getNewToken);
/**
* @brief Provides access to internal aspects of the cache as a test hook.
*
*/
class Internals;
#if defined(TESTING_BUILD)
/**
* @brief Clears token cache. Intended to only be used in tests.
*
*/
static void Clear();
#endif
};
}}} // namespace Azure::Identity::_detail

View File

@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Token cache internals and test hooks.
*/
#pragma once
#include "token_cache.hpp"
#include <azure/core/credentials/credentials.hpp>
#include <functional>
#include <map>
#include <memory>
#include <shared_mutex>
#include <string>
#include <tuple>
namespace Azure { namespace Identity { namespace _detail {
/**
* @brief Implements internal aspects of token cache and provides test hooks.
*
*/
class TokenCache::Internals final {
Internals() = delete;
~Internals() = delete;
public:
/**
* @brief Represents a unique set of characteristics that are used to distinguish between cache
* entries.
*
*/
struct CacheKey final
{
std::string TenantId; ///< Tenant ID.
std::string ClientId; ///< Client ID.
std::string AuthorityHost; ///< Authority Host.
std::string Scopes; ///< Authentication Scopes as a single string.
bool operator<(TokenCache::Internals::CacheKey const& other) const
{
return std::tie(TenantId, ClientId, AuthorityHost, Scopes)
< std::tie(other.TenantId, other.ClientId, other.AuthorityHost, other.Scopes);
}
};
/**
* @brief Represents immediate cache value (token) and a synchronization primitive to handle its
* updates.
*
*/
struct CacheValue final
{
std::shared_timed_mutex ElementMutex;
Core::Credentials::AccessToken AccessToken;
};
/**
* @brief The cache itself.
*
*/
static std::map<CacheKey, std::shared_ptr<CacheValue>> Cache;
/**
* @brief Mutex to access the cache container.
*
*/
static std::shared_timed_mutex CacheMutex;
#if defined(TESTING_BUILD)
/**
* A test hook that gets invoked before cache write lock gets acquired.
*
*/
static std::function<void()> OnBeforeCacheWriteLock;
/**
* A test hook that gets invoked before item write lock gets acquired.
*
*/
static std::function<void()> OnBeforeItemWriteLock;
#endif
};
}}} // namespace Azure::Identity::_detail

View File

@ -0,0 +1,154 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "private/token_cache.hpp"
#include "private/token_cache_internals.hpp"
#include <chrono>
#include <mutex>
using Azure::Identity::_detail::TokenCache;
using Azure::DateTime;
using Azure::Core::Credentials::AccessToken;
decltype(TokenCache::Internals::Cache) TokenCache::Internals::Cache;
decltype(TokenCache::Internals::CacheMutex) TokenCache::Internals::CacheMutex;
#if defined(TESTING_BUILD)
std::function<void()> TokenCache::Internals::OnBeforeCacheWriteLock;
std::function<void()> TokenCache::Internals::OnBeforeItemWriteLock;
#endif
namespace {
bool IsFresh(
std::shared_ptr<TokenCache::Internals::CacheValue> const& item,
DateTime::duration minimumExpiration,
std::chrono::system_clock::time_point now)
{
return item->AccessToken.ExpiresOn > (DateTime(now) + minimumExpiration);
}
std::shared_ptr<TokenCache::Internals::CacheValue> GetOrCreateValue(
TokenCache::Internals::CacheKey const& key,
DateTime::duration minimumExpiration)
{
{
std::shared_lock<std::shared_timed_mutex> cacheReadLock(TokenCache::Internals::CacheMutex);
auto const found = TokenCache::Internals::Cache.find(key);
if (found != TokenCache::Internals::Cache.end())
{
return found->second;
}
}
#if defined(TESTING_BUILD)
if (TokenCache::Internals::OnBeforeCacheWriteLock != nullptr)
{
TokenCache::Internals::OnBeforeCacheWriteLock();
}
#endif
std::unique_lock<std::shared_timed_mutex> cacheWriteLock(TokenCache::Internals::CacheMutex);
// Search cache for the second time, in case the item was inserted between releasing the read lock
// and acquiring the write lock.
auto const found = TokenCache::Internals::Cache.find(key);
if (found != TokenCache::Internals::Cache.end())
{
return found->second;
}
// Clean up cache from expired items (once every N insertions).
{
auto const cacheSize = TokenCache::Internals::Cache.size();
// N: cacheSize (before insertion) is >= 32 and is a power of two.
// 32 as a starting point does not have any special meaning.
//
// Power of 2 trick:
// https://www.exploringbinary.com/ten-ways-to-check-if-an-integer-is-a-power-of-two-in-c/
if (cacheSize >= 32 && (cacheSize & (cacheSize - 1)) == 0)
{
auto now = std::chrono::system_clock::now();
auto iter = TokenCache::Internals::Cache.begin();
while (iter != TokenCache::Internals::Cache.end())
{
// Should we end up erasing the element, iterator to current will become invalid, after
// which we can't increment it. So we copy current, and safely advance the loop iterator.
auto const curr = iter;
++iter;
// We will try to obtain a write lock, but in a non-blocking way. We only lock it if no one
// was holding it for read and write at a time. If it's busy in any way, we don't wait, but
// move on.
auto const item = curr->second;
{
std::unique_lock<std::shared_timed_mutex> lock(item->ElementMutex, std::defer_lock);
if (lock.try_lock() && !IsFresh(item, minimumExpiration, now))
{
TokenCache::Internals::Cache.erase(curr);
}
}
}
}
}
// Insert the blank value value and return it.
return TokenCache::Internals::Cache[key] = std::make_shared<TokenCache::Internals::CacheValue>();
}
} // namespace
AccessToken TokenCache::GetToken(
std::string const& tenantId,
std::string const& clientId,
std::string const& authorityHost,
std::string const& scopes,
DateTime::duration minimumExpiration,
std::function<AccessToken()> const& getNewToken)
{
auto const item
= GetOrCreateValue({tenantId, clientId, authorityHost, scopes}, minimumExpiration);
{
std::shared_lock<std::shared_timed_mutex> itemReadLock(item->ElementMutex);
if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now()))
{
return item->AccessToken;
}
}
{
#if defined(TESTING_BUILD)
if (TokenCache::Internals::OnBeforeItemWriteLock != nullptr)
{
TokenCache::Internals::OnBeforeItemWriteLock();
}
#endif
std::unique_lock<std::shared_timed_mutex> itemWriteLock(item->ElementMutex);
// Check the expiration for the second time, in case it just got updated, after releasing the
// itemReadLock, and before acquiring itemWriteLock.
if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now()))
{
return item->AccessToken;
}
auto const newToken = getNewToken();
item->AccessToken = newToken;
return newToken;
}
}
#if defined(TESTING_BUILD)
void TokenCache::Clear()
{
std::unique_lock<std::shared_timed_mutex> cacheWriteLock(TokenCache::Internals::CacheMutex);
Internals::Cache.clear();
}
#endif

View File

@ -25,6 +25,7 @@ add_executable (
macro_guard_test.cpp
managed_identity_credential_test.cpp
simplified_header_test.cpp
token_cache_test.cpp
token_credential_impl_test.cpp
token_credential_test.cpp
)

View File

@ -3,6 +3,7 @@
#include "credential_test_helper.hpp"
#include "private/token_cache_internals.hpp"
#include <azure/core/internal/environment.hpp>
#include <stdlib.h>
@ -70,6 +71,8 @@ CredentialTestHelper::TokenRequestSimulationResult CredentialTestHelper::Simulat
std::vector<TokenRequestSimulationServerResponse> const& responses,
GetTokenCallback getToken)
{
Azure::Identity::_detail::TokenCache::Clear();
using Azure::Core::Context;
using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;

View File

@ -0,0 +1,450 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "private/token_cache_internals.hpp"
#include <mutex>
#include <gtest/gtest.h>
using Azure::DateTime;
using Azure::Core::Credentials::AccessToken;
using Azure::Identity::_detail::TokenCache;
using namespace std::chrono_literals;
TEST(TokenCache, KeyComparison)
{
using Key = TokenCache::Internals::CacheKey;
Key const key1{"a", "b", "c", "d"};
EXPECT_FALSE(key1 < key1);
{
Key const key1dup{"a", "b", "c", "d"};
EXPECT_FALSE(key1 < key1dup);
EXPECT_FALSE(key1dup < key1);
}
Key const key2{"a", "b", "c", "~"};
Key const key3{"a", "b", "~", "d"};
Key const key4{"a", "~", "c", "d"};
Key const key5{"~", "b", "c", "d"};
EXPECT_TRUE(key1 < key2);
EXPECT_TRUE(key1 < key3);
EXPECT_TRUE(key1 < key4);
EXPECT_TRUE(key1 < key5);
EXPECT_FALSE(key2 < key1);
EXPECT_FALSE(key3 < key1);
EXPECT_FALSE(key4 < key1);
EXPECT_FALSE(key5 < key1);
EXPECT_TRUE(key2 < key3);
EXPECT_TRUE(key2 < key4);
EXPECT_TRUE(key2 < key5);
EXPECT_FALSE(key3 < key2);
EXPECT_FALSE(key4 < key2);
EXPECT_FALSE(key5 < key2);
EXPECT_TRUE(key3 < key4);
EXPECT_TRUE(key3 < key5);
EXPECT_FALSE(key4 < key3);
EXPECT_FALSE(key5 < key3);
EXPECT_TRUE(key4 < key5);
EXPECT_FALSE(key5 < key4);
}
TEST(TokenCache, GetReuseRefresh)
{
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const Yesterday = Tomorrow - 48h;
{
auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token1.ExpiresOn, Tomorrow);
EXPECT_EQ(token1.Token, "T1");
auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 24h;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn);
EXPECT_EQ(token1.Token, token2.Token);
}
{
TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->AccessToken.ExpiresOn = Yesterday;
auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T3";
result.ExpiresOn = Tomorrow + 1min;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min);
EXPECT_EQ(token.Token, "T3");
}
}
TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey)
{
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
TokenCache::Internals::OnBeforeCacheWriteLock = [=]() {
TokenCache::Internals::OnBeforeCacheWriteLock = nullptr;
static_cast<void>(TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
}));
};
auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before "
"acquiring cache write lock");
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 1min;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "T1");
}
TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken)
{
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const Yesterday = Tomorrow - 48h;
{
TokenCache::Internals::OnBeforeItemWriteLock = [=]() {
TokenCache::Internals::OnBeforeItemWriteLock = nullptr;
auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}];
item->AccessToken.Token = "T1";
item->AccessToken.ExpiresOn = Tomorrow;
};
auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before "
"acquiring item write lock");
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 1min;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "T1");
}
// Same as above, but the token that was inserted is already expired.
{
TokenCache::Clear();
TokenCache::Internals::OnBeforeItemWriteLock = [=]() {
TokenCache::Internals::OnBeforeItemWriteLock = nullptr;
auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}];
item->AccessToken.Token = "T3";
item->AccessToken.ExpiresOn = Yesterday;
};
auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T4";
result.ExpiresOn = Tomorrow + 3min;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min);
EXPECT_EQ(token.Token, "T4");
}
}
TEST(TokenCache, ExpiredCleanup)
{
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const Yesterday = Tomorrow - 48h;
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
for (auto i = 1; i <= 65; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(TokenCache::GetToken(n, n, n, n, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
}));
}
// Simply: we added 64+1 token, none of them has expired. None are expected to be cleaned up.
EXPECT_EQ(TokenCache::Internals::Cache.size(), 65UL);
// Let's expire 3 of them, with numbers from 1 to 3.
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday;
}
// Add tokens up to 128 total. When 129th gets added, clean up should get triggered.
for (auto i = 66; i <= 128; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(TokenCache::GetToken(n, n, n, n, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
}));
}
EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL);
// Count is at 128. Tokens from 1 to 3 are still in cache even though they are expired.
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end());
}
// One more addition to the cache and cleanup for the expired ones will get triggered.
static_cast<void>(TokenCache::GetToken("129", "129", "129", "129", 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
}));
// We were at 128 before we added 1 more, and now we're at 126. 3 were deleted, 1 was added.
EXPECT_EQ(TokenCache::Internals::Cache.size(), 126UL);
// Items from 1 to 3 should no longer be in the cache.
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end());
}
// Let's expire items from 21 all the way up to 129.
for (auto i = 21; i <= 129; ++i)
{
auto const n = std::to_string(i);
TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday;
}
// Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to
// 128 items (with numbers from 2 to 129, and number 1 missing).
for (auto i = 2; i <= 3; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(TokenCache::GetToken(n, n, n, n, 2min, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow;
return result;
}));
}
// Cache is now at 128 again (items from 2 to 129). Adding 1 more will trigger cleanup.
EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL);
// Now let's lock some of the items for reading, and some for writing. Cleanup should not block on
// token release, but will simply move on, without doing anything to the ones that were locked.
// Out of 4 locked, two are expired, so they should get cleared under normla circumstances, but
// this time they will remain in the cache.
std::shared_lock<std::shared_timed_mutex> readLockForUnexpired(
TokenCache::Internals::Cache[{"2", "2", "2", "2"}]->ElementMutex);
std::shared_lock<std::shared_timed_mutex> readLockForExpired(
TokenCache::Internals::Cache[{"127", "127", "127", "127"}]->ElementMutex);
std::unique_lock<std::shared_timed_mutex> writeLockForUnexpired(
TokenCache::Internals::Cache[{"3", "3", "3", "3"}]->ElementMutex);
std::unique_lock<std::shared_timed_mutex> writeLockForExpired(
TokenCache::Internals::Cache[{"128", "128", "128", "128"}]->ElementMutex);
// Count is at 128. Inserting the 129th element, and it will trigger cleanup.
static_cast<void>(TokenCache::GetToken("1", "1", "1", "1", 2min, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow;
return result;
}));
// These should be 20 unexpired items + two that are expired but were locked, so 22 total.
EXPECT_EQ(TokenCache::Internals::Cache.size(), 22UL);
for (auto i = 1; i <= 20; ++i)
{
auto const n = std::to_string(i);
EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end());
}
EXPECT_NE(
TokenCache::Internals::Cache.find({"127", "127", "127", "127"}),
TokenCache::Internals::Cache.end());
EXPECT_NE(
TokenCache::Internals::Cache.find({"128", "128", "128", "128"}),
TokenCache::Internals::Cache.end());
for (auto i = 21; i <= 126; ++i)
{
auto const n = std::to_string(i);
EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end());
}
}
TEST(TokenCache, MinimumExpiration)
{
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token1.ExpiresOn, Tomorrow);
EXPECT_EQ(token1.Token, "T1");
auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 24h, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 1h;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h);
EXPECT_EQ(token2.Token, "T2");
}
TEST(TokenCache, MultithreadedAccess)
{
TokenCache::Clear();
EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token1.ExpiresOn, Tomorrow);
EXPECT_EQ(token1.Token, "T1");
{
std::shared_lock<std::shared_timed_mutex> itemReadLock(
TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex);
{
std::shared_lock<std::shared_timed_mutex> cacheReadLock(TokenCache::Internals::CacheMutex);
// Parallel threads read both the container and the item we're accessing, and we can access it
// in parallel as well.
auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 1h;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL);
EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn);
EXPECT_EQ(token2.Token, token1.Token);
}
// The cache is unlocked, but one item is being read in a parallel thread, which does not
// prevent new items (with different key) from being appended to cache.
auto const token3 = TokenCache::GetToken("E", "F", "G", "H", 2min, [=]() {
AccessToken result;
result.Token = "T3";
result.ExpiresOn = Tomorrow + 2h;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 2UL);
EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h);
EXPECT_EQ(token3.Token, "T3");
}
{
std::unique_lock<std::shared_timed_mutex> itemWriteLock(
TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex);
// The cache is unlocked, but one item is being written in a parallel thread, which does not
// prevent new items (with different key) from being appended to cache.
auto const token3 = TokenCache::GetToken("I", "J", "K", "L", 2min, [=]() {
AccessToken result;
result.Token = "T4";
result.ExpiresOn = Tomorrow + 3h;
return result;
});
EXPECT_EQ(TokenCache::Internals::Cache.size(), 3UL);
EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h);
EXPECT_EQ(token3.Token, "T4");
}
}

View File

@ -3,6 +3,8 @@
#include "private/token_credential_impl.hpp"
#include "private/token_cache.hpp"
#include "credential_test_helper.hpp"
#include <memory>
@ -18,6 +20,7 @@ using Azure::Core::Credentials::TokenCredential;
using Azure::Core::Credentials::TokenCredentialOptions;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Http::HttpMethod;
using Azure::Identity::_detail::TokenCache;
using Azure::Identity::_detail::TokenCredentialImpl;
using Azure::Identity::Test::_detail::CredentialTestHelper;
@ -50,6 +53,8 @@ public:
AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const& context)
const override
{
TokenCache::Clear();
return m_tokenCredentialImpl->GetToken(context, [&]() {
m_throwingFunction();

View File

@ -7,6 +7,8 @@
#include <azure/identity/client_secret_credential.hpp>
#include <azure/identity/environment_credential.hpp>
#include "private/token_cache.hpp"
#include <chrono>
#include <thread>
@ -58,6 +60,8 @@ TEST_F(TokenCredentialTest, ClientSecret)
std::string const testName(GetTestName());
auto const clientSecretCredential = GetClientSecretCredential(testName);
_detail::TokenCache::Clear();
auto const token = clientSecretCredential->GetToken(
{{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext);
@ -70,6 +74,8 @@ TEST_F(TokenCredentialTest, EnvironmentCredential)
std::string const testName(GetTestName());
auto const clientSecretCredential = GetEnvironmentCredential(testName);
_detail::TokenCache::Clear();
auto const token = clientSecretCredential->GetToken(
{{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext);