diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 5df529dd9..acc325a46 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -27,6 +27,7 @@ "*.exe", "*.a", "*.lib", + "*.svg", "*.yaml", ".github/CODEOWNERS", ".github/CODEOWNERS_baseline_errors.txt", diff --git a/eng/pipelines/templates/jobs/live.tests.yml b/eng/pipelines/templates/jobs/live.tests.yml index 02d54b914..fc4333b2d 100644 --- a/eng/pipelines/templates/jobs/live.tests.yml +++ b/eng/pipelines/templates/jobs/live.tests.yml @@ -267,6 +267,9 @@ jobs: useGlobalConfig: true env: ${{ insert }}: ${{ parameters.EnvVars }} + # Set fake authority host to ensure Managed Identity fail for Default Azure Credential + # so "execute samples" step correctly picks up Azure CLI credential. + AZURE_POD_IDENTITY_AUTHORITY_HOST: 'FakeAuthorityHost' - ${{ else }}: - bash: | @@ -292,7 +295,10 @@ jobs: displayName: "Run Samples for : ${{ parameters.ServiceDirectory }}" condition: and(succeeded(), eq(variables['RunSamples'], '1')) env: - ${{ insert }}: ${{ parameters.EnvVars }} + ${{ insert }}: ${{ parameters.EnvVars }} + # Set fake authority host to ensure Managed Identity fail for Default Azure Credential + # so "execute samples" step correctly picks up Azure CLI credential. + AZURE_POD_IDENTITY_AUTHORITY_HOST: 'FakeAuthorityHost' # Make coverage targets (specified in coverage_targets.txt) and assemble # coverage report diff --git a/samples/helpers/service/src/client.cpp b/samples/helpers/service/src/client.cpp index bf5316c29..bccf42249 100644 --- a/samples/helpers/service/src/client.cpp +++ b/samples/helpers/service/src/client.cpp @@ -11,7 +11,7 @@ void Azure::Service::Client::DoSomething(const Azure::Core::Context& context) co #if (0) // Every client has its own scope. We use management.azure.com here as an example. Core::Credentials::TokenRequestContext azureServiceClientContext; - azureServiceClientContext.Scopes = {"https://management.azure.com/"}; + azureServiceClientContext.Scopes = {"https://management.azure.com/.default"}; auto authenticationToken = m_credential->GetToken(azureServiceClientContext, context); diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 7724d5891..c29d3082d 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -10,6 +10,8 @@ ### Bugs Fixed +- [[#4952]](https://github.com/Azure/azure-sdk-for-cpp/issues/4952) Improved HTTP Transport implementations' request timeouts to not exceed context deadlines. + ### Other Changes ### Acknowledgments diff --git a/sdk/core/azure-core/inc/azure/core/http/transport.hpp b/sdk/core/azure-core/inc/azure/core/http/transport.hpp index 95c01906b..c580c1369 100644 --- a/sdk/core/azure-core/inc/azure/core/http/transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/transport.hpp @@ -9,6 +9,7 @@ #pragma once #include "azure/core/context.hpp" +#include "azure/core/dll_import_export.hpp" #include "azure/core/http/http.hpp" #include "azure/core/http/raw_response.hpp" @@ -16,6 +17,15 @@ namespace Azure { namespace Core { namespace Http { + namespace _internal { + /** + * @brief A context key to use to pass connection timeout (`std::chrono::milliseconds`) to an + * HTTP transport. + * + */ + AZ_CORE_DLLEXPORT extern const Context::Key HttpConnectionTimeout; + } // namespace _internal + /** * @brief Base class for all HTTP transport implementations. */ diff --git a/sdk/core/azure-core/src/http/curl/curl.cpp b/sdk/core/azure-core/src/http/curl/curl.cpp index 2453398fe..a1178dd4e 100644 --- a/sdk/core/azure-core/src/http/curl/curl.cpp +++ b/sdk/core/azure-core/src/http/curl/curl.cpp @@ -349,9 +349,20 @@ std::unique_ptr CurlTransport::Send(Request& request, Context const // Create CurlSession to perform request Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Creating a new session."); + auto connectionTimeoutOverride = std::chrono::milliseconds{0}; + { + std::chrono::milliseconds contextConnectionTimeout{0}; + if (context.TryGetValue(Http::_internal::HttpConnectionTimeout, contextConnectionTimeout) + && contextConnectionTimeout.count() > 0) + { + connectionTimeoutOverride = contextConnectionTimeout; + } + } + auto session = std::make_unique( request, - CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection(request, m_options), + CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection( + request, m_options, connectionTimeoutOverride, false), m_options); CURLcode performing; @@ -380,6 +391,7 @@ std::unique_ptr CurlTransport::Send(Request& request, Context const CurlConnectionPool::g_curlConnectionPool.ExtractOrCreateCurlConnection( request, m_options, + connectionTimeoutOverride, getConnectionOpenIntent + 1 >= _detail::RequestPoolResetAfterConnectionFailed), m_options); } @@ -1386,11 +1398,42 @@ size_t CurlSession::ResponseBufferParser::Parse( } namespace { +// Calculates the effective timeout value, based on options.ConnectionTimeout and +// connectionTimeoutOverride. Returns 0 if default. +long GetConnectionTimeout( + CurlTransportOptions const& options, + std::chrono::milliseconds connectionTimeoutOverride) +{ + auto connectionTimeout + = (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout) + ? options.ConnectionTimeout + : std::chrono::milliseconds{0}; + + if (connectionTimeoutOverride.count() > 0) + { + connectionTimeout = connectionTimeout.count() > 0 + ? (std::min)(connectionTimeout, connectionTimeoutOverride) + : connectionTimeoutOverride; + } + + auto connectionTimeoutLong = 0; + if (connectionTimeout.count() > 0 + && connectionTimeout.count() <= std::numeric_limits::max()) + { + connectionTimeoutLong = static_cast(connectionTimeout.count()); + } + + return connectionTimeoutLong; +} + // Calculate the connection key. // The connection key is a tuple of host, proxy info, TLS info, etc. Basically any characteristics // of the connection that should indicate that the connection shouldn't be re-used should be listed // the connection key. -inline std::string GetConnectionKey(std::string const& host, CurlTransportOptions const& options) +inline std::string GetConnectionKey( + std::string const& host, + CurlTransportOptions const& options, + std::chrono::milliseconds connectionTimeoutOverride) { std::string key(host); key.append(","); @@ -1427,12 +1470,7 @@ inline std::string GetConnectionKey(std::string const& host, CurlTransportOption key.append("0"); #endif key.append(","); - // using DefaultConnectionTimeout or 0 result in the same setting - key.append( - (options.ConnectionTimeout == Azure::Core::Http::_detail::DefaultConnectionTimeout - || options.ConnectionTimeout == std::chrono::milliseconds(0)) - ? "0" - : std::to_string(options.ConnectionTimeout.count())); + key.append(std::to_string(GetConnectionTimeout(options, connectionTimeoutOverride))); return key; } @@ -2182,13 +2220,15 @@ int CurlConnection::SslCtxCallback(CURL*, void* sslctx) std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlConnection( Request& request, CurlTransportOptions const& options, + std::chrono::milliseconds connectionTimeoutOverride, bool resetPool) { uint16_t port = request.GetUrl().GetPort(); // Generate a display name for the host being connected to std::string const& hostDisplayName = request.GetUrl().GetScheme() + "://" + request.GetUrl().GetHost() + (port != 0 ? ":" + std::to_string(port) : ""); - std::string const connectionKey = GetConnectionKey(hostDisplayName, options); + std::string const connectionKey + = GetConnectionKey(hostDisplayName, options, connectionTimeoutOverride); { decltype(CurlConnectionPool::g_curlConnectionPool @@ -2239,7 +2279,8 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo // No available connection for the pool for the required host. Create one Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Spawn new connection."); - return std::make_unique(request, options, hostDisplayName, connectionKey); + return std::make_unique( + request, options, hostDisplayName, connectionKey, connectionTimeoutOverride); } // Move the connection back to the connection pool. Push it to the front so it becomes the @@ -2306,7 +2347,8 @@ CurlConnection::CurlConnection( Request& request, CurlTransportOptions const& options, std::string const& hostDisplayName, - std::string const& connectionPropertiesKey) + std::string const& connectionPropertiesKey, + std::chrono::milliseconds connectionTimeoutOverride) : m_connectionKey(connectionPropertiesKey) { m_handle = Azure::Core::_internal::UniqueHandle(curl_easy_init()); @@ -2408,15 +2450,17 @@ CurlConnection::CurlConnection( + std::string(curl_easy_strerror(result))); } - if (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout) { - if (!SetLibcurlOption(m_handle, CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result)) + const long connectionTimeout = GetConnectionTimeout(options, connectionTimeoutOverride); + if (connectionTimeout > 0) { - throw Azure::Core::Http::TransportException( - _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName - + ". Fail setting connect timeout to: " - + std::to_string(options.ConnectionTimeout.count()) + " ms. " - + std::string(curl_easy_strerror(result))); + if (!SetLibcurlOption(m_handle, CURLOPT_CONNECTTIMEOUT_MS, connectionTimeout, &result)) + { + throw Azure::Core::Http::TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + + ". Fail setting connect timeout to: " + std::to_string(connectionTimeout) + " ms. " + + std::string(curl_easy_strerror(result))); + } } } diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp index bae02d5dc..4248ff25a 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -77,6 +78,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { * @param request HTTP request to get #Azure::Core::Http::CurlNetworkConnection for. * @param options The connection settings which includes host name and libcurl handle specific * configuration. + * @param connectionTimeoutOverride If greater than 0, specifies the override value for the + * ConnectionTimeout value, specified in options. * @param resetPool Request the pool to remove all current connections for the provided * options to force the creation of a new connection. * @@ -85,6 +88,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { std::unique_ptr ExtractOrCreateCurlConnection( Request& request, CurlTransportOptions const& options, + std::chrono::milliseconds connectionTimeoutOverride = std::chrono::milliseconds{0}, bool resetPool = false); /** diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp index b122af19f..11fbec488 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp @@ -196,6 +196,8 @@ namespace Azure { namespace Core { * @param request Remote request * @param options Connection options. * @param hostDisplayName Display name for remote host, used for diagnostics. + * @param connectionTimeoutOverride If greater than 0, specifies the override value for the + * ConnectionTimeout value, specified in options. * * @param connectionPropertiesKey CURL connection properties key */ @@ -203,7 +205,8 @@ namespace Azure { namespace Core { Azure::Core::Http::Request& request, Azure::Core::Http::CurlTransportOptions const& options, std::string const& hostDisplayName, - std::string const& connectionPropertiesKey); + std::string const& connectionPropertiesKey, + std::chrono::milliseconds connectionTimeoutOverride); /** * @brief Destructor. diff --git a/sdk/core/azure-core/src/http/http.cpp b/sdk/core/azure-core/src/http/http.cpp index 5d5fa3a17..43ae1721f 100644 --- a/sdk/core/azure-core/src/http/http.cpp +++ b/sdk/core/azure-core/src/http/http.cpp @@ -3,6 +3,7 @@ #include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" #include "azure/core/internal/strings.hpp" #include "azure/core/url.hpp" @@ -12,11 +13,13 @@ using namespace Azure::Core; using namespace Azure::Core::Http; -char const Azure::Core::Http::_internal::HttpShared::ContentType[] = "content-type"; -char const Azure::Core::Http::_internal::HttpShared::ApplicationJson[] = "application/json"; -char const Azure::Core::Http::_internal::HttpShared::Accept[] = "accept"; -char const Azure::Core::Http::_internal::HttpShared::MsRequestId[] = "x-ms-request-id"; -char const Azure::Core::Http::_internal::HttpShared::MsClientRequestId[] = "x-ms-client-request-id"; +const Context::Key Http::_internal::HttpConnectionTimeout{}; + +char const Http::_internal::HttpShared::ContentType[] = "content-type"; +char const Http::_internal::HttpShared::ApplicationJson[] = "application/json"; +char const Http::_internal::HttpShared::Accept[] = "accept"; +char const Http::_internal::HttpShared::MsRequestId[] = "x-ms-request-id"; +char const Http::_internal::HttpShared::MsClientRequestId[] = "x-ms-client-request-id"; const HttpMethod HttpMethod::Get("GET"); const HttpMethod HttpMethod::Head("HEAD"); diff --git a/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp b/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp index e0da5951c..76e14a132 100644 --- a/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp +++ b/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -26,7 +27,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { std::unique_ptr<_detail::WinHttpRequest> CreateRequestHandle( Azure::Core::_internal::UniqueHandle const& connectionHandle, Azure::Core::Url const& url, - Azure::Core::Http::HttpMethod const& method); + Azure::Core::Http::HttpMethod const& method, + std::chrono::milliseconds connectionTimeout); // Callback to allow a derived transport to extract the request handle. Used for WebSocket // transports. diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp b/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp index 41dd27f12..9ce50d881 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp @@ -22,6 +22,7 @@ #include +#include #include #include #pragma warning(push) @@ -179,7 +180,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { Azure::Core::Url const& url, Azure::Core::Http::HttpMethod const& method, PCCERT_CONTEXT tlsClientCertificate, - WinHttpTransportOptions const& options); + WinHttpTransportOptions const& options, + std::chrono::milliseconds connectionTimeout); ~WinHttpRequest(); void MarkRequestHandleClosed() { m_requestHandleClosed = true; }; diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp index 3ff991e70..410aa1220 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp @@ -1257,6 +1257,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { * @param method The HTTP method to use for the request. * @param tlsClientCertificate The client certificate to use for the request. * @param options The transport options to use for the request. + * @param connectionTimeout Connection timeout in milliseconds. * * @remark Note that we *cannot* use the TlsClientCertificate field in the options passed into * this function because the creator of the associated WinHttpTransport object may have freed the @@ -1269,7 +1270,8 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { Azure::Core::Url const& url, Azure::Core::Http::HttpMethod const& method, PCCERT_CONTEXT tlsClientCertificate, - WinHttpTransportOptions const& options) + WinHttpTransportOptions const& options, + std::chrono::milliseconds connectionTimeout) : m_expectedTlsRootCertificates(options.ExpectedTlsRootCertificates), m_tlsClientCertificate(CertDuplicateCertificateContext(tlsClientCertificate)) { @@ -1324,6 +1326,20 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { } } + if (connectionTimeout.count() > 0 + && connectionTimeout.count() < std::numeric_limits::max()) + { + auto timeoutMillisecondsULong = static_cast(connectionTimeout.count()); + if (!WinHttpSetOption( + m_requestHandle.get(), + WINHTTP_OPTION_CONNECT_TIMEOUT, + &timeoutMillisecondsULong, + sizeof(timeoutMillisecondsULong))) + { + GetErrorAndThrow("Error while setting connection timeout."); + } + } + if (!options.ProxyInformation.empty()) { WINHTTP_PROXY_INFO proxyInfo{}; @@ -1436,10 +1452,11 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { std::unique_ptr WinHttpTransportImpl::CreateRequestHandle( Azure::Core::_internal::UniqueHandle const& connectionHandle, Azure::Core::Url const& url, - Azure::Core::Http::HttpMethod const& method) + Azure::Core::Http::HttpMethod const& method, + std::chrono::milliseconds connectionTimeout) { auto request{std::make_unique( - connectionHandle, url, method, m_tlsClientCertificate.get(), m_options)}; + connectionHandle, url, method, m_tlsClientCertificate.get(), m_options, connectionTimeout)}; // If we are supporting WebSockets, then let WinHTTP know that it should // prepare to upgrade the HttpRequest to a WebSocket. if (HasWebSocketSupport()) @@ -1797,10 +1814,20 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { std::unique_ptr WinHttpTransportImpl::Send(Request& request, Context const& context) { + auto connectionTimeout = std::chrono::milliseconds{0}; + { + std::chrono::milliseconds contextConnectionTimeout{0}; + if (context.TryGetValue(Http::_internal::HttpConnectionTimeout, contextConnectionTimeout) + && contextConnectionTimeout.count() > 0) + { + connectionTimeout = contextConnectionTimeout; + } + } + Azure::Core::_internal::UniqueHandle connectionHandle = CreateConnectionHandle(request.GetUrl(), context); - std::unique_ptr<_detail::WinHttpRequest> requestHandle( - CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod())); + std::unique_ptr<_detail::WinHttpRequest> requestHandle(CreateRequestHandle( + connectionHandle, request.GetUrl(), request.GetMethod(), connectionTimeout)); requestHandle->SendRequest(request, context); requestHandle->ReceiveResponse(context); diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 2ed6e31a6..b14883724 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -10,8 +10,13 @@ ### Bugs Fixed +- [[#4952]](https://github.com/Azure/azure-sdk-for-cpp/issues/4952) Fixed `ManagedIdentityCredential` to fail fast if IMDS authentication is not available. +- [[#4669]](https://github.com/Azure/azure-sdk-for-cpp/issues/4669) Fixed the order of credentials in `DefaultAzureCredential`: `ManagedIdentityCredential` before `AzureCliCredential`. + ### Other Changes +- Added support for overriding IMDS authority host in the `ManagedIdentityCredential` via `AZURE_POD_IDENTITY_AUTHORITY_HOST` environment variable. + ## 1.11.0 (2025-04-08) ### Features Added diff --git a/sdk/identity/azure-identity/src/default_azure_credential.cpp b/sdk/identity/azure-identity/src/default_azure_credential.cpp index d5c7023ed..1b859893d 100644 --- a/sdk/identity/azure-identity/src/default_azure_credential.cpp +++ b/sdk/identity/azure-identity/src/default_azure_credential.cpp @@ -43,10 +43,11 @@ DefaultAzureCredential::DefaultAzureCredential( "is the better fit for the application."); // Creating credentials in order to ensure the order of log messages. - ChainedTokenCredential::Sources miSources; + ChainedTokenCredential::Sources credentialChain; { - miSources.emplace_back(std::make_shared(options)); - miSources.emplace_back(std::make_shared(options)); + credentialChain.emplace_back(std::make_shared(options)); + credentialChain.emplace_back(std::make_shared(options)); + credentialChain.emplace_back(std::make_shared(options)); constexpr auto envVarName = "AZURE_TOKEN_CREDENTIALS"; const auto envVarValue = Environment::GetVariable(envVarName); @@ -69,7 +70,7 @@ DefaultAzureCredential::DefaultAzureCredential( || StringExtensions::LocaleInvariantCaseInsensitiveEqual(trimmedEnvVarValue, "dev")) { IdentityLog::Write(IdentityLog::Level::Verbose, logMsg); - miSources.emplace_back(std::make_shared(options)); + credentialChain.emplace_back(std::make_shared(options)); } else { @@ -78,14 +79,12 @@ DefaultAzureCredential::DefaultAzureCredential( + "' environment variable. Allowed values are 'dev' and 'prod' (case insensitive). " "It is also valid to not have the environment variable defined."); } - - miSources.emplace_back(std::make_shared(options)); } // DefaultAzureCredential caches the selected credential, so that it can be reused on subsequent // calls. m_impl = std::make_unique<_detail::ChainedTokenCredentialImpl>( - GetCredentialName(), std::move(miSources), true); + GetCredentialName(), std::move(credentialChain), true); } DefaultAzureCredential::~DefaultAzureCredential() = default; diff --git a/sdk/identity/azure-identity/src/managed_identity_source.cpp b/sdk/identity/azure-identity/src/managed_identity_source.cpp index 20202800a..93a5969bf 100644 --- a/sdk/identity/azure-identity/src/managed_identity_source.cpp +++ b/sdk/identity/azure-identity/src/managed_identity_source.cpp @@ -5,9 +5,11 @@ #include "private/identity_log.hpp" +#include #include #include +#include #include #include #include @@ -21,12 +23,15 @@ using Azure::Core::_internal::Environment; using Azure::Identity::_detail::IdentityLog; namespace { - -// https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service -// IMDS is a REST API that's available at a well-known, non-routable IP address (169.254.169.254). -// You can only access it from within the VM. Communication between the VM and IMDS never leaves the -// host. -std::string const ImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; +// First request for IMDS should not be taking tens of seconds - if IMDS is unavailable, we should +// fail fast. Among other reasons, this improves user experience when ManagedIdentityCredential is +// part of DefaultAzureCredential. Especially given that all the service credentials are earlier in +// the chain than the developer tool credentials, if ManagedIdentityCredential makes a request which +// takes 30 seconds to time out (host is not available), plus we make 3 retries of that request, and +// all that to figure out that IMDS is not available before moving on to AzureCliCredential, it will +// significantly worsen user experience when using DAC. Therefore, we need the timeout below (plus +// we have logic to not retry that request). +constexpr std::chrono::milliseconds ImdsFirstRequestConnectionTimeout = std::chrono::seconds{1}; std::string WithSourceAndClientIdMessage(std::string const& credSource, std::string const& clientId) { @@ -496,26 +501,48 @@ std::unique_ptr ImdsManagedIdentitySource::Create( std::string const& resourceId, Azure::Core::Credentials::TokenCredentialOptions const& options) { + const std::string ImdsName = "Azure Instance Metadata Service"; + IdentityLog::Write( IdentityLog::Level::Informational, - credName + " will be created" - + WithSourceAndClientIdMessage("Azure Instance Metadata Service", clientId) + credName + " will be created" + WithSourceAndClientIdMessage(ImdsName, clientId) + ".\nSuccessful creation does not guarantee further successful token retrieval."); + // https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service + // IMDS is a REST API that's available at a well-known, non-routable IP address + // (169.254.169.254). You can only access it from within the VM. Communication between the VM + // and IMDS never leaves the host. + // 'AZURE_POD_IDENTITY_AUTHORITY_HOST' environment variable allows user to override the + // authority host for IMDS. This is consistent with other language SDKs. + Core::Url imdsUrl{"http://169.254.169.254"}; + constexpr auto ImdsEndpointEnvVarName = "AZURE_POD_IDENTITY_AUTHORITY_HOST"; + const auto imdsEndpointEnvVarValue = Environment::GetVariable(ImdsEndpointEnvVarName); + if (!imdsEndpointEnvVarValue.empty()) + { + IdentityLog::Write( + IdentityLog::Level::Verbose, + credName + WithSourceAndClientIdMessage(ImdsName, {}) + ": '" + ImdsEndpointEnvVarName + + "' environment variable is set, so customized authority host ('" + + imdsEndpointEnvVarValue + "') will be used."); + + imdsUrl = Core::Url{imdsEndpointEnvVarValue}; + } + imdsUrl.SetPath("/metadata/identity/oauth2/token"); + return std::unique_ptr( - new ImdsManagedIdentitySource(clientId, objectId, resourceId, options)); + new ImdsManagedIdentitySource(clientId, objectId, resourceId, imdsUrl, options)); } ImdsManagedIdentitySource::ImdsManagedIdentitySource( std::string const& clientId, std::string const& objectId, std::string const& resourceId, + Azure::Core::Url const& imdsUrl, Azure::Core::Credentials::TokenCredentialOptions const& options) : ManagedIdentitySource(clientId, std::string(), options), - m_request(Azure::Core::Http::HttpMethod::Get, Azure::Core::Url(ImdsEndpoint)) + m_request(Azure::Core::Http::HttpMethod::Get, imdsUrl) { { - using Azure::Core::Url; auto& url = m_request.GetUrl(); url.AppendQueryParameter("api-version", "2018-02-01"); @@ -538,6 +565,11 @@ ImdsManagedIdentitySource::ImdsManagedIdentitySource( } m_request.SetHeader("Metadata", "true"); + + Core::Credentials::TokenCredentialOptions firstRequestOptions = options; + firstRequestOptions.Retry.MaxRetries = 0; + m_firstRequestPipeline = std::make_unique(firstRequestOptions); + m_firstRequestSucceeded = false; } Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( @@ -558,7 +590,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() { - return TokenCredentialImpl::GetToken(context, true, [&]() { + std::function()> const& createRequest = [&]() { auto request = std::make_unique(m_request); if (!scopesStr.empty()) @@ -567,6 +599,28 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( } return request; - }); + }; + + if (!m_firstRequestSucceeded) + { + std::unique_lock lock(m_firstRequestMutex); + if (!m_firstRequestSucceeded) + { + const auto token = m_firstRequestPipeline->GetToken( + context.WithValue( + Core::Http::_internal::HttpConnectionTimeout, ImdsFirstRequestConnectionTimeout), + true, + createRequest); + + m_firstRequestSucceeded = true; + + lock.unlock(); + m_firstRequestPipeline.reset(); + + return token; + } + } + + return TokenCredentialImpl::GetToken(context, true, createRequest); }); } diff --git a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp index bdd005941..4b86f99ab 100644 --- a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp +++ b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp @@ -10,7 +10,9 @@ #include #include +#include #include +#include #include #include @@ -193,11 +195,15 @@ namespace Azure { namespace Identity { namespace _detail { class ImdsManagedIdentitySource final : public ManagedIdentitySource { private: Core::Http::Request m_request; + mutable std::unique_ptr m_firstRequestPipeline; + mutable std::mutex m_firstRequestMutex; + mutable std::atomic m_firstRequestSucceeded; explicit ImdsManagedIdentitySource( std::string const& clientId, std::string const& objectId, std::string const& resourceId, + Core::Url const& imdsUrl, Core::Credentials::TokenCredentialOptions const& options); public: diff --git a/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp b/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp index 6e90c4ce3..0f6cad58c 100644 --- a/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/default_azure_credential_test.cpp @@ -241,33 +241,6 @@ TEST_P(LogMessages, ) EXPECT_EQ( log.at(i).second, "Identity: WorkloadIdentityCredential was created successfully."); - { - const auto variableSetWording = azTokenCredsEnvVarValue.empty() - ? "not set" - : ("set to '" + azTokenCredsEnvVarValue + "'"); - - const auto beIncludedWording = isDev ? "" : "NOT "; - - ++i; - EXPECT_EQ(log.at(i).first, Logger::Level::Verbose); - EXPECT_EQ( - log.at(i).second, - "Identity: DefaultAzureCredential: " - "'AZURE_TOKEN_CREDENTIALS' environment variable is " - + variableSetWording + ", therefore AzureCliCredential will " - + beIncludedWording + "be included in the credential chain."); - } - - if (isDev) - { - ++i; - EXPECT_EQ(log.at(i).first, Logger::Level::Informational); - EXPECT_EQ( - log.at(i).second, - "Identity: AzureCliCredential created." - "\nSuccessful creation does not guarantee further successful token retrieval."); - } - ++i; EXPECT_EQ(log.at(i).first, Logger::Level::Verbose); EXPECT_EQ( @@ -304,14 +277,41 @@ TEST_P(LogMessages, ) "with Azure Instance Metadata Service source." "\nSuccessful creation does not guarantee further successful token retrieval."); + { + const auto variableSetWording = azTokenCredsEnvVarValue.empty() + ? "not set" + : ("set to '" + azTokenCredsEnvVarValue + "'"); + + const auto beIncludedWording = isDev ? "" : "NOT "; + + ++i; + EXPECT_EQ(log.at(i).first, Logger::Level::Verbose); + EXPECT_EQ( + log.at(i).second, + "Identity: DefaultAzureCredential: " + "'AZURE_TOKEN_CREDENTIALS' environment variable is " + + variableSetWording + ", therefore AzureCliCredential will " + + beIncludedWording + "be included in the credential chain."); + } + + if (isDev) + { + ++i; + EXPECT_EQ(log.at(i).first, Logger::Level::Informational); + EXPECT_EQ( + log.at(i).second, + "Identity: AzureCliCredential created." + "\nSuccessful creation does not guarantee further successful token retrieval."); + } + ++i; EXPECT_EQ(log.at(i).first, Logger::Level::Informational); EXPECT_EQ( log.at(i).second, std::string( "Identity: DefaultAzureCredential: Created with the following credentials: " - "EnvironmentCredential, WorkloadIdentityCredential, ") - + (isDev ? "AzureCliCredential, " : "") + "ManagedIdentityCredential."); + "EnvironmentCredential, WorkloadIdentityCredential, ManagedIdentityCredential") + + (isDev ? ", AzureCliCredential" : "") + "."); ++i; EXPECT_EQ(i, log.size()); diff --git a/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp b/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp index ba82eb0d1..9fa671b6f 100644 --- a/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/managed_identity_credential_test.cpp @@ -3038,4 +3038,159 @@ namespace Azure { namespace Identity { namespace Test { EXPECT_LE(response2.AccessToken.ExpiresOn, response2.LatestExpiration + 3600s); } + TEST(ManagedIdentityCredential, ImdsCustomHost) + { + using Azure::Core::Diagnostics::Logger; + using LogMsgVec = std::vector>; + LogMsgVec log; + + try + { + auto const actual1 = CredentialTestHelper::SimulateTokenRequest( + [](auto transport) { + TokenCredentialOptions options; + options.Transport.Transport = transport; + + CredentialTestHelper::EnvironmentOverride const env({ + {"MSI_ENDPOINT", ""}, + {"MSI_SECRET", ""}, + {"IDENTITY_ENDPOINT", "https://visualstudio.com/"}, + {"IMDS_ENDPOINT", ""}, + {"IDENTITY_HEADER", ""}, + {"IDENTITY_SERVER_THUMBPRINT", ""}, + {"AZURE_POD_IDENTITY_AUTHORITY_HOST", ""}, + }); + + return std::make_unique( + "fedcba98-7654-3210-0123-456789abcdef", options); + }, + {{"https://azure.com/.default"}}, + {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}"}); + + auto const actual2 = CredentialTestHelper::SimulateTokenRequest( + [](auto transport) { + TokenCredentialOptions options; + options.Transport.Transport = transport; + + CredentialTestHelper::EnvironmentOverride const env({ + {"MSI_ENDPOINT", ""}, + {"MSI_SECRET", ""}, + {"IDENTITY_ENDPOINT", ""}, + {"IMDS_ENDPOINT", "https://xbox.com/"}, + {"IDENTITY_HEADER", ""}, + {"IDENTITY_SERVER_THUMBPRINT", ""}, + {"AZURE_POD_IDENTITY_AUTHORITY_HOST", "https://custom.imds.endpoint/"}, + }); + + return std::make_unique( + "01234567-89ab-cdef-fedc-ba9876543210", options); + }, + {{"https://outlook.com/.default"}}, + {"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"}); + + Logger::SetLevel(Logger::Level::Verbose); + Logger::SetListener([&](auto lvl, auto msg) { log.push_back(std::make_pair(lvl, msg)); }); + + auto const actual3 = CredentialTestHelper::SimulateTokenRequest( + [&](auto transport) { + TokenCredentialOptions options; + options.Transport.Transport = transport; + + CredentialTestHelper::EnvironmentOverride const env({ + {"MSI_ENDPOINT", ""}, + {"MSI_SECRET", ""}, + {"IDENTITY_ENDPOINT", ""}, + {"IMDS_ENDPOINT", "https://xbox.com/"}, + {"IDENTITY_HEADER", ""}, + {"IDENTITY_SERVER_THUMBPRINT", ""}, + {"AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://localhost:59202"}, + }); + + auto credential = std::make_unique( + "01234567-89ab-cdef-fedc-ba9876543210", options); + + EXPECT_EQ(log.size(), LogMsgVec::size_type(6)); + + // The 5 previous messages are verified in other tests. + EXPECT_EQ(log.at(5).first, Logger::Level::Verbose); + EXPECT_EQ( + log.at(5).second, + "Identity: ManagedIdentityCredential with Azure Instance Metadata Service source: " + "'AZURE_POD_IDENTITY_AUTHORITY_HOST' environment variable is set, " + "so customized authority host ('http://localhost:59202') will be used."); + + log.clear(); + + return credential; + }, + {{"https://outlook.com/.default"}}, + {"{\"expires_in\":7200, \"access_token\":\"ACCESSTOKEN2\"}"}); + + EXPECT_EQ(actual1.Requests.size(), 1U); + EXPECT_EQ(actual1.Responses.size(), 1U); + + EXPECT_EQ(actual2.Requests.size(), 1U); + EXPECT_EQ(actual2.Responses.size(), 1U); + + auto const& request1 = actual1.Requests.at(0); + auto const& response1 = actual1.Responses.at(0); + + auto const& request2 = actual2.Requests.at(0); + auto const& response2 = actual2.Responses.at(0); + + EXPECT_EQ(request1.HttpMethod, HttpMethod::Get); + EXPECT_EQ(request2.HttpMethod, HttpMethod::Get); + + EXPECT_EQ( + request1.AbsoluteUrl, + "http://169.254.169.254/metadata/identity/oauth2/token" + "?api-version=2018-02-01" + "&client_id=fedcba98-7654-3210-0123-456789abcdef" + "&resource=https%3A%2F%2Fazure.com"); // cspell:disable-line + + EXPECT_EQ( + request2.AbsoluteUrl, + "https://custom.imds.endpoint/metadata/identity/oauth2/token" + "?api-version=2018-02-01" + "&client_id=01234567-89ab-cdef-fedc-ba9876543210" + "&resource=https%3A%2F%2Foutlook.com"); // cspell:disable-line + + auto const& request3 = actual3.Requests.at(0); + EXPECT_EQ( + request3.AbsoluteUrl, + "http://localhost:59202/metadata/identity/oauth2/token" + "?api-version=2018-02-01" + "&client_id=01234567-89ab-cdef-fedc-ba9876543210" + "&resource=https%3A%2F%2Foutlook.com"); // cspell:disable-line + + EXPECT_TRUE(request1.Body.empty()); + EXPECT_TRUE(request2.Body.empty()); + + { + EXPECT_NE(request1.Headers.find("Metadata"), request1.Headers.end()); + EXPECT_EQ(request1.Headers.at("Metadata"), "true"); + + EXPECT_NE(request2.Headers.find("Metadata"), request2.Headers.end()); + EXPECT_EQ(request2.Headers.at("Metadata"), "true"); + } + + EXPECT_EQ(response1.AccessToken.Token, "ACCESSTOKEN1"); + EXPECT_EQ(response2.AccessToken.Token, "ACCESSTOKEN2"); + + using namespace std::chrono_literals; + EXPECT_GE(response1.AccessToken.ExpiresOn, response1.EarliestExpiration + 3600s); + EXPECT_LE(response1.AccessToken.ExpiresOn, response1.LatestExpiration + 3600s); + + EXPECT_GE(response2.AccessToken.ExpiresOn, response2.EarliestExpiration + 3600s); + EXPECT_LE(response2.AccessToken.ExpiresOn, response2.LatestExpiration + 3600s); + } + catch (...) + { + Logger::SetListener(nullptr); + throw; + } + + Logger::SetListener(nullptr); + } + }}} // namespace Azure::Identity::Test diff --git a/sdk/tables/test-resources.json b/sdk/tables/test-resources.json index abf9292d6..f80e96353 100644 --- a/sdk/tables/test-resources.json +++ b/sdk/tables/test-resources.json @@ -95,6 +95,7 @@ ], "defaultAction": "Allow" }, + "allowSharedKeyAccess": false, "supportsHttpsTrafficOnly": true, "allowBlobPublicAccess": true, "encryption": {