Identity: IMDS fail-fast and Cred order change in DAC (and Core support) (#6573)

* Identity: IMDS fail-fast and Cred order change in DAC (and Core support)

* Mac fix and cspell update

* Update unit test and clang-format

* Temporarily update samples to use AzCliCred until recordings are re-recorded

* Revert samples back to use DAC

* Remove SAS auth from Tables template

* Clang-format

* Add support for 'AZURE_POD_IDENTITY_AUTHORITY_HOST', override it for running samples in CI

* Add unit test for AZURE_POD_IDENTITY_AUTHORITY_HOST

* "in milliseconds"

Co-authored-by: Scott Addie <10702007+scottaddie@users.noreply.github.com>

* PR Feedback

* Named constant + comment

---------

Co-authored-by: Anton Kolesnyk <antkmsft@users.noreply.github.com>
Co-authored-by: Scott Addie <10702007+scottaddie@users.noreply.github.com>
This commit is contained in:
Anton Kolesnyk 2025-05-30 16:47:50 -07:00 committed by GitHub
parent 0e04dd0c63
commit a035ee5f94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 406 additions and 82 deletions

1
.vscode/cspell.json vendored
View File

@ -27,6 +27,7 @@
"*.exe",
"*.a",
"*.lib",
"*.svg",
"*.yaml",
".github/CODEOWNERS",
".github/CODEOWNERS_baseline_errors.txt",

View File

@ -267,6 +267,9 @@ jobs:
useGlobalConfig: true
env:
${{ insert }}: ${{ parameters.EnvVars }}
# Set fake authority host to ensure Managed Identity fail for Default Azure Credential
# so "execute samples" step correctly picks up Azure CLI credential.
AZURE_POD_IDENTITY_AUTHORITY_HOST: 'FakeAuthorityHost'
- ${{ else }}:
- bash: |
@ -292,7 +295,10 @@ jobs:
displayName: "Run Samples for : ${{ parameters.ServiceDirectory }}"
condition: and(succeeded(), eq(variables['RunSamples'], '1'))
env:
${{ insert }}: ${{ parameters.EnvVars }}
${{ insert }}: ${{ parameters.EnvVars }}
# Set fake authority host to ensure Managed Identity fail for Default Azure Credential
# so "execute samples" step correctly picks up Azure CLI credential.
AZURE_POD_IDENTITY_AUTHORITY_HOST: 'FakeAuthorityHost'
# Make coverage targets (specified in coverage_targets.txt) and assemble
# coverage report

View File

@ -11,7 +11,7 @@ void Azure::Service::Client::DoSomething(const Azure::Core::Context& context) co
#if (0)
// Every client has its own scope. We use management.azure.com here as an example.
Core::Credentials::TokenRequestContext azureServiceClientContext;
azureServiceClientContext.Scopes = {"https://management.azure.com/"};
azureServiceClientContext.Scopes = {"https://management.azure.com/.default"};
auto authenticationToken = m_credential->GetToken(azureServiceClientContext, context);

View File

@ -10,6 +10,8 @@
### Bugs Fixed
- [[#4952]](https://github.com/Azure/azure-sdk-for-cpp/issues/4952) Improved HTTP Transport implementations' request timeouts to not exceed context deadlines.
### Other Changes
### Acknowledgments

View File

@ -9,6 +9,7 @@
#pragma once
#include "azure/core/context.hpp"
#include "azure/core/dll_import_export.hpp"
#include "azure/core/http/http.hpp"
#include "azure/core/http/raw_response.hpp"
@ -16,6 +17,15 @@
namespace Azure { namespace Core { namespace Http {
namespace _internal {
/**
* @brief A context key to use to pass connection timeout (`std::chrono::milliseconds`) to an
* HTTP transport.
*
*/
AZ_CORE_DLLEXPORT extern const Context::Key HttpConnectionTimeout;
} // namespace _internal
/**
* @brief Base class for all HTTP transport implementations.
*/

View File

@ -349,9 +349,20 @@ std::unique_ptr<RawResponse> CurlTransport::Send(Request& request, Context const
// Create CurlSession to perform request
Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Creating a new session.");
auto connectionTimeoutOverride = std::chrono::milliseconds{0};
{
std::chrono::milliseconds contextConnectionTimeout{0};
if (context.TryGetValue(Http::_internal::HttpConnectionTimeout, contextConnectionTimeout)
&& contextConnectionTimeout.count() > 0)
{
connectionTimeoutOverride = contextConnectionTimeout;
}
}
auto session = std::make_unique<CurlSession>(
request,
CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection(request, m_options),
CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection(
request, m_options, connectionTimeoutOverride, false),
m_options);
CURLcode performing;
@ -380,6 +391,7 @@ std::unique_ptr<RawResponse> CurlTransport::Send(Request& request, Context const
CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection(
request,
m_options,
connectionTimeoutOverride,
getConnectionOpenIntent + 1 >= _detail::RequestPoolResetAfterConnectionFailed),
m_options);
}
@ -1386,11 +1398,42 @@ size_t CurlSession::ResponseBufferParser::Parse(
}
namespace {
// Calculates the effective timeout value, based on options.ConnectionTimeout and
// connectionTimeoutOverride. Returns 0 if default.
long GetConnectionTimeout(
CurlTransportOptions const& options,
std::chrono::milliseconds connectionTimeoutOverride)
{
auto connectionTimeout
= (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout)
? options.ConnectionTimeout
: std::chrono::milliseconds{0};
if (connectionTimeoutOverride.count() > 0)
{
connectionTimeout = connectionTimeout.count() > 0
? (std::min)(connectionTimeout, connectionTimeoutOverride)
: connectionTimeoutOverride;
}
auto connectionTimeoutLong = 0;
if (connectionTimeout.count() > 0
&& connectionTimeout.count() <= std::numeric_limits<long>::max())
{
connectionTimeoutLong = static_cast<long>(connectionTimeout.count());
}
return connectionTimeoutLong;
}
// Calculate the connection key.
// The connection key is a tuple of host, proxy info, TLS info, etc. Basically any characteristics
// of the connection that should indicate that the connection shouldn't be re-used should be listed
// the connection key.
inline std::string GetConnectionKey(std::string const& host, CurlTransportOptions const& options)
inline std::string GetConnectionKey(
std::string const& host,
CurlTransportOptions const& options,
std::chrono::milliseconds connectionTimeoutOverride)
{
std::string key(host);
key.append(",");
@ -1427,12 +1470,7 @@ inline std::string GetConnectionKey(std::string const& host, CurlTransportOption
key.append("0");
#endif
key.append(",");
// using DefaultConnectionTimeout or 0 result in the same setting
key.append(
(options.ConnectionTimeout == Azure::Core::Http::_detail::DefaultConnectionTimeout
|| options.ConnectionTimeout == std::chrono::milliseconds(0))
? "0"
: std::to_string(options.ConnectionTimeout.count()));
key.append(std::to_string(GetConnectionTimeout(options, connectionTimeoutOverride)));
return key;
}
@ -2182,13 +2220,15 @@ int CurlConnection::SslCtxCallback(CURL*, void* sslctx)
std::unique_ptr<CurlNetworkConnection> CurlConnectionPool::ExtractOrCreateCurlConnection(
Request& request,
CurlTransportOptions const& options,
std::chrono::milliseconds connectionTimeoutOverride,
bool resetPool)
{
uint16_t port = request.GetUrl().GetPort();
// Generate a display name for the host being connected to
std::string const& hostDisplayName = request.GetUrl().GetScheme() + "://"
+ request.GetUrl().GetHost() + (port != 0 ? ":" + std::to_string(port) : "");
std::string const connectionKey = GetConnectionKey(hostDisplayName, options);
std::string const connectionKey
= GetConnectionKey(hostDisplayName, options, connectionTimeoutOverride);
{
decltype(CurlConnectionPool::g_curlConnectionPool
@ -2239,7 +2279,8 @@ std::unique_ptr<CurlNetworkConnection> CurlConnectionPool::ExtractOrCreateCurlCo
// No available connection for the pool for the required host. Create one
Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Spawn new connection.");
return std::make_unique<CurlConnection>(request, options, hostDisplayName, connectionKey);
return std::make_unique<CurlConnection>(
request, options, hostDisplayName, connectionKey, connectionTimeoutOverride);
}
// Move the connection back to the connection pool. Push it to the front so it becomes the
@ -2306,7 +2347,8 @@ CurlConnection::CurlConnection(
Request& request,
CurlTransportOptions const& options,
std::string const& hostDisplayName,
std::string const& connectionPropertiesKey)
std::string const& connectionPropertiesKey,
std::chrono::milliseconds connectionTimeoutOverride)
: m_connectionKey(connectionPropertiesKey)
{
m_handle = Azure::Core::_internal::UniqueHandle<CURL>(curl_easy_init());
@ -2408,15 +2450,17 @@ CurlConnection::CurlConnection(
+ std::string(curl_easy_strerror(result)));
}
if (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout)
{
if (!SetLibcurlOption(m_handle, CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result))
const long connectionTimeout = GetConnectionTimeout(options, connectionTimeoutOverride);
if (connectionTimeout > 0)
{
throw Azure::Core::Http::TransportException(
_detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName
+ ". Fail setting connect timeout to: "
+ std::to_string(options.ConnectionTimeout.count()) + " ms. "
+ std::string(curl_easy_strerror(result)));
if (!SetLibcurlOption(m_handle, CURLOPT_CONNECTTIMEOUT_MS, connectionTimeout, &result))
{
throw Azure::Core::Http::TransportException(
_detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName
+ ". Fail setting connect timeout to: " + std::to_string(connectionTimeout) + " ms. "
+ std::string(curl_easy_strerror(result)));
}
}
}

View File

@ -16,6 +16,7 @@
#include <azure/core/http/curl_transport.hpp>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <list>
#include <memory>
@ -77,6 +78,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
* @param request HTTP request to get #Azure::Core::Http::CurlNetworkConnection for.
* @param options The connection settings which includes host name and libcurl handle specific
* configuration.
* @param connectionTimeoutOverride If greater than 0, specifies the override value for the
* ConnectionTimeout value, specified in options.
* @param resetPool Request the pool to remove all current connections for the provided
* options to force the creation of a new connection.
*
@ -85,6 +88,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
std::unique_ptr<CurlNetworkConnection> ExtractOrCreateCurlConnection(
Request& request,
CurlTransportOptions const& options,
std::chrono::milliseconds connectionTimeoutOverride = std::chrono::milliseconds{0},
bool resetPool = false);
/**

View File

@ -196,6 +196,8 @@ namespace Azure { namespace Core {
* @param request Remote request
* @param options Connection options.
* @param hostDisplayName Display name for remote host, used for diagnostics.
* @param connectionTimeoutOverride If greater than 0, specifies the override value for the
* ConnectionTimeout value, specified in options.
*
* @param connectionPropertiesKey CURL connection properties key
*/
@ -203,7 +205,8 @@ namespace Azure { namespace Core {
Azure::Core::Http::Request& request,
Azure::Core::Http::CurlTransportOptions const& options,
std::string const& hostDisplayName,
std::string const& connectionPropertiesKey);
std::string const& connectionPropertiesKey,
std::chrono::milliseconds connectionTimeoutOverride);
/**
* @brief Destructor.

View File

@ -3,6 +3,7 @@
#include "azure/core/http/http.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/internal/strings.hpp"
#include "azure/core/url.hpp"
@ -12,11 +13,13 @@
using namespace Azure::Core;
using namespace Azure::Core::Http;
char const Azure::Core::Http::_internal::HttpShared::ContentType[] = "content-type";
char const Azure::Core::Http::_internal::HttpShared::ApplicationJson[] = "application/json";
char const Azure::Core::Http::_internal::HttpShared::Accept[] = "accept";
char const Azure::Core::Http::_internal::HttpShared::MsRequestId[] = "x-ms-request-id";
char const Azure::Core::Http::_internal::HttpShared::MsClientRequestId[] = "x-ms-client-request-id";
const Context::Key Http::_internal::HttpConnectionTimeout{};
char const Http::_internal::HttpShared::ContentType[] = "content-type";
char const Http::_internal::HttpShared::ApplicationJson[] = "application/json";
char const Http::_internal::HttpShared::Accept[] = "accept";
char const Http::_internal::HttpShared::MsRequestId[] = "x-ms-request-id";
char const Http::_internal::HttpShared::MsClientRequestId[] = "x-ms-client-request-id";
const HttpMethod HttpMethod::Get("GET");
const HttpMethod HttpMethod::Head("HEAD");

View File

@ -3,6 +3,7 @@
#pragma once
#include <chrono>
#include <memory>
#include <wincrypt.h>
@ -26,7 +27,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
std::unique_ptr<_detail::WinHttpRequest> CreateRequestHandle(
Azure::Core::_internal::UniqueHandle<void*> const& connectionHandle,
Azure::Core::Url const& url,
Azure::Core::Http::HttpMethod const& method);
Azure::Core::Http::HttpMethod const& method,
std::chrono::milliseconds connectionTimeout);
// Callback to allow a derived transport to extract the request handle. Used for WebSocket
// transports.

View File

@ -22,6 +22,7 @@
#include <windows.h>
#include <chrono>
#include <memory>
#include <mutex>
#pragma warning(push)
@ -179,7 +180,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
Azure::Core::Url const& url,
Azure::Core::Http::HttpMethod const& method,
PCCERT_CONTEXT tlsClientCertificate,
WinHttpTransportOptions const& options);
WinHttpTransportOptions const& options,
std::chrono::milliseconds connectionTimeout);
~WinHttpRequest();
void MarkRequestHandleClosed() { m_requestHandleClosed = true; };

View File

@ -1257,6 +1257,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
* @param method The HTTP method to use for the request.
* @param tlsClientCertificate The client certificate to use for the request.
* @param options The transport options to use for the request.
* @param connectionTimeout Connection timeout in milliseconds.
*
* @remark Note that we *cannot* use the TlsClientCertificate field in the options passed into
* this function because the creator of the associated WinHttpTransport object may have freed the
@ -1269,7 +1270,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
Azure::Core::Url const& url,
Azure::Core::Http::HttpMethod const& method,
PCCERT_CONTEXT tlsClientCertificate,
WinHttpTransportOptions const& options)
WinHttpTransportOptions const& options,
std::chrono::milliseconds connectionTimeout)
: m_expectedTlsRootCertificates(options.ExpectedTlsRootCertificates),
m_tlsClientCertificate(CertDuplicateCertificateContext(tlsClientCertificate))
{
@ -1324,6 +1326,20 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
}
}
if (connectionTimeout.count() > 0
&& connectionTimeout.count() < std::numeric_limits<ULONG>::max())
{
auto timeoutMillisecondsULong = static_cast<ULONG>(connectionTimeout.count());
if (!WinHttpSetOption(
m_requestHandle.get(),
WINHTTP_OPTION_CONNECT_TIMEOUT,
&timeoutMillisecondsULong,
sizeof(timeoutMillisecondsULong)))
{
GetErrorAndThrow("Error while setting connection timeout.");
}
}
if (!options.ProxyInformation.empty())
{
WINHTTP_PROXY_INFO proxyInfo{};
@ -1436,10 +1452,11 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
std::unique_ptr<WinHttpRequest> WinHttpTransportImpl::CreateRequestHandle(
Azure::Core::_internal::UniqueHandle<HINTERNET> const& connectionHandle,
Azure::Core::Url const& url,
Azure::Core::Http::HttpMethod const& method)
Azure::Core::Http::HttpMethod const& method,
std::chrono::milliseconds connectionTimeout)
{
auto request{std::make_unique<WinHttpRequest>(
connectionHandle, url, method, m_tlsClientCertificate.get(), m_options)};
connectionHandle, url, method, m_tlsClientCertificate.get(), m_options, connectionTimeout)};
// If we are supporting WebSockets, then let WinHTTP know that it should
// prepare to upgrade the HttpRequest to a WebSocket.
if (HasWebSocketSupport())
@ -1797,10 +1814,20 @@ namespace Azure { namespace Core { namespace Http { namespace _detail {
std::unique_ptr<RawResponse> WinHttpTransportImpl::Send(Request& request, Context const& context)
{
auto connectionTimeout = std::chrono::milliseconds{0};
{
std::chrono::milliseconds contextConnectionTimeout{0};
if (context.TryGetValue(Http::_internal::HttpConnectionTimeout, contextConnectionTimeout)
&& contextConnectionTimeout.count() > 0)
{
connectionTimeout = contextConnectionTimeout;
}
}
Azure::Core::_internal::UniqueHandle<HINTERNET> connectionHandle
= CreateConnectionHandle(request.GetUrl(), context);
std::unique_ptr<_detail::WinHttpRequest> requestHandle(
CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod()));
std::unique_ptr<_detail::WinHttpRequest> requestHandle(CreateRequestHandle(
connectionHandle, request.GetUrl(), request.GetMethod(), connectionTimeout));
requestHandle->SendRequest(request, context);
requestHandle->ReceiveResponse(context);

View File

@ -10,8 +10,13 @@
### Bugs Fixed
- [[#4952]](https://github.com/Azure/azure-sdk-for-cpp/issues/4952) Fixed `ManagedIdentityCredential` to fail fast if IMDS authentication is not available.
- [[#4669]](https://github.com/Azure/azure-sdk-for-cpp/issues/4669) Fixed the order of credentials in `DefaultAzureCredential`: `ManagedIdentityCredential` before `AzureCliCredential`.
### Other Changes
- Added support for overriding IMDS authority host in the `ManagedIdentityCredential` via `AZURE_POD_IDENTITY_AUTHORITY_HOST` environment variable.
## 1.11.0 (2025-04-08)
### Features Added

View File

@ -43,10 +43,11 @@ DefaultAzureCredential::DefaultAzureCredential(
"is the better fit for the application.");
// Creating credentials in order to ensure the order of log messages.
ChainedTokenCredential::Sources miSources;
ChainedTokenCredential::Sources credentialChain;
{
miSources.emplace_back(std::make_shared<EnvironmentCredential>(options));
miSources.emplace_back(std::make_shared<WorkloadIdentityCredential>(options));
credentialChain.emplace_back(std::make_shared<EnvironmentCredential>(options));
credentialChain.emplace_back(std::make_shared<WorkloadIdentityCredential>(options));
credentialChain.emplace_back(std::make_shared<ManagedIdentityCredential>(options));
constexpr auto envVarName = "AZURE_TOKEN_CREDENTIALS";
const auto envVarValue = Environment::GetVariable(envVarName);
@ -69,7 +70,7 @@ DefaultAzureCredential::DefaultAzureCredential(
|| StringExtensions::LocaleInvariantCaseInsensitiveEqual(trimmedEnvVarValue, "dev"))
{
IdentityLog::Write(IdentityLog::Level::Verbose, logMsg);
miSources.emplace_back(std::make_shared<AzureCliCredential>(options));
credentialChain.emplace_back(std::make_shared<AzureCliCredential>(options));
}
else
{
@ -78,14 +79,12 @@ DefaultAzureCredential::DefaultAzureCredential(
+ "' environment variable. Allowed values are 'dev' and 'prod' (case insensitive). "
"It is also valid to not have the environment variable defined.");
}
miSources.emplace_back(std::make_shared<ManagedIdentityCredential>(options));
}
// DefaultAzureCredential caches the selected credential, so that it can be reused on subsequent
// calls.
m_impl = std::make_unique<_detail::ChainedTokenCredentialImpl>(
GetCredentialName(), std::move(miSources), true);
GetCredentialName(), std::move(credentialChain), true);
}
DefaultAzureCredential::~DefaultAzureCredential() = default;

View File

@ -5,9 +5,11 @@
#include "private/identity_log.hpp"
#include <azure/core/http/transport.hpp>
#include <azure/core/internal/environment.hpp>
#include <azure/core/platform.hpp>
#include <chrono>
#include <fstream>
#include <iterator>
#include <stdexcept>
@ -21,12 +23,15 @@ using Azure::Core::_internal::Environment;
using Azure::Identity::_detail::IdentityLog;
namespace {
// https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service
// IMDS is a REST API that's available at a well-known, non-routable IP address (169.254.169.254).
// You can only access it from within the VM. Communication between the VM and IMDS never leaves the
// host.
std::string const ImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";
// First request for IMDS should not be taking tens of seconds - if IMDS is unavailable, we should
// fail fast. Among other reasons, this improves user experience when ManagedIdentityCredential is
// part of DefaultAzureCredential. Especially given that all the service credentials are earlier in
// the chain than the developer tool credentials, if ManagedIdentityCredential makes a request which
// takes 30 seconds to time out (host is not available), plus we make 3 retries of that request, and
// all that to figure out that IMDS is not available before moving on to AzureCliCredential, it will
// significantly worsen user experience when using DAC. Therefore, we need the timeout below (plus
// we have logic to not retry that request).
constexpr std::chrono::milliseconds ImdsFirstRequestConnectionTimeout = std::chrono::seconds{1};
std::string WithSourceAndClientIdMessage(std::string const& credSource, std::string const& clientId)
{
@ -496,26 +501,48 @@ std::unique_ptr<ManagedIdentitySource> ImdsManagedIdentitySource::Create(
std::string const& resourceId,
Azure::Core::Credentials::TokenCredentialOptions const& options)
{
const std::string ImdsName = "Azure Instance Metadata Service";
IdentityLog::Write(
IdentityLog::Level::Informational,
credName + " will be created"
+ WithSourceAndClientIdMessage("Azure Instance Metadata Service", clientId)
credName + " will be created" + WithSourceAndClientIdMessage(ImdsName, clientId)
+ ".\nSuccessful creation does not guarantee further successful token retrieval.");
// https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service
// IMDS is a REST API that's available at a well-known, non-routable IP address
// (169.254.169.254). You can only access it from within the VM. Communication between the VM
// and IMDS never leaves the host.
// 'AZURE_POD_IDENTITY_AUTHORITY_HOST' environment variable allows user to override the
// authority host for IMDS. This is consistent with other language SDKs.
Core::Url imdsUrl{"http://169.254.169.254"};
constexpr auto ImdsEndpointEnvVarName = "AZURE_POD_IDENTITY_AUTHORITY_HOST";
const auto imdsEndpointEnvVarValue = Environment::GetVariable(ImdsEndpointEnvVarName);
if (!imdsEndpointEnvVarValue.empty())
{
IdentityLog::Write(
IdentityLog::Level::Verbose,
credName + WithSourceAndClientIdMessage(ImdsName, {}) + ": '" + ImdsEndpointEnvVarName
+ "' environment variable is set, so customized authority host ('"
+ imdsEndpointEnvVarValue + "') will be used.");
imdsUrl = Core::Url{imdsEndpointEnvVarValue};
}
imdsUrl.SetPath("/metadata/identity/oauth2/token");
return std::unique_ptr<ManagedIdentitySource>(
new ImdsManagedIdentitySource(clientId, objectId, resourceId, options));
new ImdsManagedIdentitySource(clientId, objectId, resourceId, imdsUrl, options));
}
ImdsManagedIdentitySource::ImdsManagedIdentitySource(
std::string const& clientId,
std::string const& objectId,
std::string const& resourceId,
Azure::Core::Url const& imdsUrl,
Azure::Core::Credentials::TokenCredentialOptions const& options)
: ManagedIdentitySource(clientId, std::string(), options),
m_request(Azure::Core::Http::HttpMethod::Get, Azure::Core::Url(ImdsEndpoint))
m_request(Azure::Core::Http::HttpMethod::Get, imdsUrl)
{
{
using Azure::Core::Url;
auto& url = m_request.GetUrl();
url.AppendQueryParameter("api-version", "2018-02-01");
@ -538,6 +565,11 @@ ImdsManagedIdentitySource::ImdsManagedIdentitySource(
}
m_request.SetHeader("Metadata", "true");
Core::Credentials::TokenCredentialOptions firstRequestOptions = options;
firstRequestOptions.Retry.MaxRetries = 0;
m_firstRequestPipeline = std::make_unique<TokenCredentialImpl>(firstRequestOptions);
m_firstRequestSucceeded = false;
}
Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
@ -558,7 +590,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, true, [&]() {
std::function<std::unique_ptr<TokenRequest>()> const& createRequest = [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
if (!scopesStr.empty())
@ -567,6 +599,28 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
}
return request;
});
};
if (!m_firstRequestSucceeded)
{
std::unique_lock<std::mutex> lock(m_firstRequestMutex);
if (!m_firstRequestSucceeded)
{
const auto token = m_firstRequestPipeline->GetToken(
context.WithValue(
Core::Http::_internal::HttpConnectionTimeout, ImdsFirstRequestConnectionTimeout),
true,
createRequest);
m_firstRequestSucceeded = true;
lock.unlock();
m_firstRequestPipeline.reset();
return token;
}
}
return TokenCredentialImpl::GetToken(context, true, createRequest);
});
}

View File

@ -10,7 +10,9 @@
#include <azure/core/credentials/token_credential_options.hpp>
#include <azure/core/url.hpp>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
@ -193,11 +195,15 @@ namespace Azure { namespace Identity { namespace _detail {
class ImdsManagedIdentitySource final : public ManagedIdentitySource {
private:
Core::Http::Request m_request;
mutable std::unique_ptr<TokenCredentialImpl> m_firstRequestPipeline;
mutable std::mutex m_firstRequestMutex;
mutable std::atomic<bool> m_firstRequestSucceeded;
explicit ImdsManagedIdentitySource(
std::string const& clientId,
std::string const& objectId,
std::string const& resourceId,
Core::Url const& imdsUrl,
Core::Credentials::TokenCredentialOptions const& options);
public:

View File

@ -241,33 +241,6 @@ TEST_P(LogMessages, )
EXPECT_EQ(
log.at(i).second, "Identity: WorkloadIdentityCredential was created successfully.");
{
const auto variableSetWording = azTokenCredsEnvVarValue.empty()
? "not set"
: ("set to '" + azTokenCredsEnvVarValue + "'");
const auto beIncludedWording = isDev ? "" : "NOT ";
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Verbose);
EXPECT_EQ(
log.at(i).second,
"Identity: DefaultAzureCredential: "
"'AZURE_TOKEN_CREDENTIALS' environment variable is "
+ variableSetWording + ", therefore AzureCliCredential will "
+ beIncludedWording + "be included in the credential chain.");
}
if (isDev)
{
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Informational);
EXPECT_EQ(
log.at(i).second,
"Identity: AzureCliCredential created."
"\nSuccessful creation does not guarantee further successful token retrieval.");
}
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Verbose);
EXPECT_EQ(
@ -304,14 +277,41 @@ TEST_P(LogMessages, )
"with Azure Instance Metadata Service source."
"\nSuccessful creation does not guarantee further successful token retrieval.");
{
const auto variableSetWording = azTokenCredsEnvVarValue.empty()
? "not set"
: ("set to '" + azTokenCredsEnvVarValue + "'");
const auto beIncludedWording = isDev ? "" : "NOT ";
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Verbose);
EXPECT_EQ(
log.at(i).second,
"Identity: DefaultAzureCredential: "
"'AZURE_TOKEN_CREDENTIALS' environment variable is "
+ variableSetWording + ", therefore AzureCliCredential will "
+ beIncludedWording + "be included in the credential chain.");
}
if (isDev)
{
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Informational);
EXPECT_EQ(
log.at(i).second,
"Identity: AzureCliCredential created."
"\nSuccessful creation does not guarantee further successful token retrieval.");
}
++i;
EXPECT_EQ(log.at(i).first, Logger::Level::Informational);
EXPECT_EQ(
log.at(i).second,
std::string(
"Identity: DefaultAzureCredential: Created with the following credentials: "
"EnvironmentCredential, WorkloadIdentityCredential, ")
+ (isDev ? "AzureCliCredential, " : "") + "ManagedIdentityCredential.");
"EnvironmentCredential, WorkloadIdentityCredential, ManagedIdentityCredential")
+ (isDev ? ", AzureCliCredential" : "") + ".");
++i;
EXPECT_EQ(i, log.size());

View File

@ -3038,4 +3038,159 @@ namespace Azure { namespace Identity { namespace Test {
EXPECT_LE(response2.AccessToken.ExpiresOn, response2.LatestExpiration + 3600s);
}
TEST(ManagedIdentityCredential, ImdsCustomHost)
{
using Azure::Core::Diagnostics::Logger;
using LogMsgVec = std::vector<std::pair<Logger::Level, std::string>>;
LogMsgVec log;
try
{
auto const actual1 = CredentialTestHelper::SimulateTokenRequest(
[](auto transport) {
TokenCredentialOptions options;
options.Transport.Transport = transport;
CredentialTestHelper::EnvironmentOverride const env({
{"MSI_ENDPOINT", ""},
{"MSI_SECRET", ""},
{"IDENTITY_ENDPOINT", "https://visualstudio.com/"},
{"IMDS_ENDPOINT", ""},
{"IDENTITY_HEADER", ""},
{"IDENTITY_SERVER_THUMBPRINT", ""},
{"AZURE_POD_IDENTITY_AUTHORITY_HOST", ""},
});
return std::make_unique<ManagedIdentityCredential>(
"fedcba98-7654-3210-0123-456789abcdef", options);
},
{{"https://azure.com/.default"}},
{"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"});
auto const actual2 = CredentialTestHelper::SimulateTokenRequest(
[](auto transport) {
TokenCredentialOptions options;
options.Transport.Transport = transport;
CredentialTestHelper::EnvironmentOverride const env({
{"MSI_ENDPOINT", ""},
{"MSI_SECRET", ""},
{"IDENTITY_ENDPOINT", ""},
{"IMDS_ENDPOINT", "https://xbox.com/"},
{"IDENTITY_HEADER", ""},
{"IDENTITY_SERVER_THUMBPRINT", ""},
{"AZURE_POD_IDENTITY_AUTHORITY_HOST", "https://custom.imds.endpoint/"},
});
return std::make_unique<ManagedIdentityCredential>(
"01234567-89ab-cdef-fedc-ba9876543210", options);
},
{{"https://outlook.com/.default"}},
{"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"});
Logger::SetLevel(Logger::Level::Verbose);
Logger::SetListener([&](auto lvl, auto msg) { log.push_back(std::make_pair(lvl, msg)); });
auto const actual3 = CredentialTestHelper::SimulateTokenRequest(
[&](auto transport) {
TokenCredentialOptions options;
options.Transport.Transport = transport;
CredentialTestHelper::EnvironmentOverride const env({
{"MSI_ENDPOINT", ""},
{"MSI_SECRET", ""},
{"IDENTITY_ENDPOINT", ""},
{"IMDS_ENDPOINT", "https://xbox.com/"},
{"IDENTITY_HEADER", ""},
{"IDENTITY_SERVER_THUMBPRINT", ""},
{"AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://localhost:59202"},
});
auto credential = std::make_unique<ManagedIdentityCredential>(
"01234567-89ab-cdef-fedc-ba9876543210", options);
EXPECT_EQ(log.size(), LogMsgVec::size_type(6));
// The 5 previous messages are verified in other tests.
EXPECT_EQ(log.at(5).first, Logger::Level::Verbose);
EXPECT_EQ(
log.at(5).second,
"Identity: ManagedIdentityCredential with Azure Instance Metadata Service source: "
"'AZURE_POD_IDENTITY_AUTHORITY_HOST' environment variable is set, "
"so customized authority host ('http://localhost:59202') will be used.");
log.clear();
return credential;
},
{{"https://outlook.com/.default"}},
{"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"});
EXPECT_EQ(actual1.Requests.size(), 1U);
EXPECT_EQ(actual1.Responses.size(), 1U);
EXPECT_EQ(actual2.Requests.size(), 1U);
EXPECT_EQ(actual2.Responses.size(), 1U);
auto const& request1 = actual1.Requests.at(0);
auto const& response1 = actual1.Responses.at(0);
auto const& request2 = actual2.Requests.at(0);
auto const& response2 = actual2.Responses.at(0);
EXPECT_EQ(request1.HttpMethod, HttpMethod::Get);
EXPECT_EQ(request2.HttpMethod, HttpMethod::Get);
EXPECT_EQ(
request1.AbsoluteUrl,
"http://169.254.169.254/metadata/identity/oauth2/token"
"?api-version=2018-02-01"
"&client_id=fedcba98-7654-3210-0123-456789abcdef"
"&resource=https%3A%2F%2Fazure.com"); // cspell:disable-line
EXPECT_EQ(
request2.AbsoluteUrl,
"https://custom.imds.endpoint/metadata/identity/oauth2/token"
"?api-version=2018-02-01"
"&client_id=01234567-89ab-cdef-fedc-ba9876543210"
"&resource=https%3A%2F%2Foutlook.com"); // cspell:disable-line
auto const& request3 = actual3.Requests.at(0);
EXPECT_EQ(
request3.AbsoluteUrl,
"http://localhost:59202/metadata/identity/oauth2/token"
"?api-version=2018-02-01"
"&client_id=01234567-89ab-cdef-fedc-ba9876543210"
"&resource=https%3A%2F%2Foutlook.com"); // cspell:disable-line
EXPECT_TRUE(request1.Body.empty());
EXPECT_TRUE(request2.Body.empty());
{
EXPECT_NE(request1.Headers.find("Metadata"), request1.Headers.end());
EXPECT_EQ(request1.Headers.at("Metadata"), "true");
EXPECT_NE(request2.Headers.find("Metadata"), request2.Headers.end());
EXPECT_EQ(request2.Headers.at("Metadata"), "true");
}
EXPECT_EQ(response1.AccessToken.Token, "ACCESSTOKEN1");
EXPECT_EQ(response2.AccessToken.Token, "ACCESSTOKEN2");
using namespace std::chrono_literals;
EXPECT_GE(response1.AccessToken.ExpiresOn, response1.EarliestExpiration + 3600s);
EXPECT_LE(response1.AccessToken.ExpiresOn, response1.LatestExpiration + 3600s);
EXPECT_GE(response2.AccessToken.ExpiresOn, response2.EarliestExpiration + 3600s);
EXPECT_LE(response2.AccessToken.ExpiresOn, response2.LatestExpiration + 3600s);
}
catch (...)
{
Logger::SetListener(nullptr);
throw;
}
Logger::SetListener(nullptr);
}
}}} // namespace Azure::Identity::Test

View File

@ -95,6 +95,7 @@
],
"defaultAction": "Allow"
},
"allowSharedKeyAccess": false,
"supportsHttpsTrafficOnly": true,
"allowBlobPublicAccess": true,
"encryption": {