diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index a20d10996..57fe6bbc2 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -77,6 +77,7 @@ set( src/managed_identity_credential.cpp src/managed_identity_source.cpp src/private/chained_token_credential_impl.hpp + src/private/client_assertion_credential_impl.hpp src/private/identity_log.hpp src/private/managed_identity_source.hpp src/private/package_version.hpp diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_pipelines_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_pipelines_credential.hpp index 0a96fc7b8..fb50cffb2 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_pipelines_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_pipelines_credential.hpp @@ -8,6 +8,7 @@ #pragma once +#include "azure/identity/client_assertion_credential.hpp" #include "azure/identity/detail/client_credential_core.hpp" #include "azure/identity/detail/token_cache.hpp" @@ -20,7 +21,7 @@ namespace Azure { namespace Identity { namespace _detail { - class TokenCredentialImpl; + class ClientAssertionCredentialImpl; } // namespace _detail /** @@ -57,12 +58,9 @@ namespace Azure { namespace Identity { private: std::string m_serviceConnectionId; std::string m_systemAccessToken; - _detail::ClientCredentialCore m_clientCredentialCore; Azure::Core::Http::_internal::HttpPipeline m_httpPipeline; std::string m_oidcRequestUrl; - std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; - std::string m_requestBody; - _detail::TokenCache m_tokenCache; + std::unique_ptr<_detail::ClientAssertionCredentialImpl> m_clientAssertionCredentialImpl; std::string GetAssertion(Core::Context const& context) const; Azure::Core::Http::Request CreateOidcRequestMessage() const; diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_assertion_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_assertion_credential.hpp index e1fd87552..3a9af365a 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_assertion_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_assertion_credential.hpp @@ -9,17 +9,15 @@ #pragma once #include "azure/identity/detail/client_credential_core.hpp" -#include "azure/identity/detail/token_cache.hpp" #include -#include #include #include namespace Azure { namespace Identity { namespace _detail { - class TokenCredentialImpl; + class ClientAssertionCredentialImpl; } // namespace _detail /** @@ -55,11 +53,7 @@ namespace Azure { namespace Identity { */ class ClientAssertionCredential final : public Core::Credentials::TokenCredential { private: - std::function m_assertionCallback; - _detail::ClientCredentialCore m_clientCredentialCore; - std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; - std::string m_requestBody; - _detail::TokenCache m_tokenCache; + std::unique_ptr<_detail::ClientAssertionCredentialImpl> m_impl; public: /** diff --git a/sdk/identity/azure-identity/inc/azure/identity/workload_identity_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/workload_identity_credential.hpp index b1e9798e3..0e481e99c 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/workload_identity_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/workload_identity_credential.hpp @@ -8,6 +8,7 @@ #pragma once +#include "azure/identity/client_assertion_credential.hpp" #include "azure/identity/detail/client_credential_core.hpp" #include "azure/identity/detail/token_cache.hpp" @@ -18,7 +19,7 @@ namespace Azure { namespace Identity { namespace _detail { - class TokenCredentialImpl; + class ClientAssertionCredentialImpl; } // namespace _detail /** @@ -74,12 +75,11 @@ namespace Azure { namespace Identity { */ class WorkloadIdentityCredential final : public Core::Credentials::TokenCredential { private: - _detail::TokenCache m_tokenCache; - _detail::ClientCredentialCore m_clientCredentialCore; - std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; - std::string m_requestBody; + std::unique_ptr<_detail::ClientAssertionCredentialImpl> m_clientAssertionCredentialImpl; std::string m_tokenFilePath; + std::string GetAssertion(Core::Context const& context) const; + public: /** * @brief Constructs a Workload Identity Credential. diff --git a/sdk/identity/azure-identity/src/azure_pipelines_credential.cpp b/sdk/identity/azure-identity/src/azure_pipelines_credential.cpp index 2fb8f0757..3665fe6a5 100644 --- a/sdk/identity/azure-identity/src/azure_pipelines_credential.cpp +++ b/sdk/identity/azure-identity/src/azure_pipelines_credential.cpp @@ -3,10 +3,10 @@ #include "azure/identity/azure_pipelines_credential.hpp" +#include "private/client_assertion_credential_impl.hpp" #include "private/identity_log.hpp" #include "private/package_version.hpp" #include "private/tenant_id_resolver.hpp" -#include "private/token_credential_impl.hpp" #include @@ -28,30 +28,6 @@ using Azure::Core::Json::_internal::json; using Azure::Identity::_detail::IdentityLog; using Azure::Identity::_detail::PackageVersion; using Azure::Identity::_detail::TenantIdResolver; -using Azure::Identity::_detail::TokenCredentialImpl; - -namespace { -bool IsValidTenantId(std::string const& tenantId) -{ - const std::string allowedChars = ".-"; - if (tenantId.empty()) - { - return false; - } - for (auto const c : tenantId) - { - if (allowedChars.find(c) != std::string::npos) - { - continue; - } - if (!StringExtensions::IsAlphaNumeric(c)) - { - return false; - } - } - return true; -} -} // namespace AzurePipelinesCredential::AzurePipelinesCredential( std::string tenantId, @@ -61,26 +37,10 @@ AzurePipelinesCredential::AzurePipelinesCredential( AzurePipelinesCredentialOptions const& options) : TokenCredential("AzurePipelinesCredential"), m_serviceConnectionId(serviceConnectionId), m_systemAccessToken(systemAccessToken), - m_clientCredentialCore(tenantId, options.AuthorityHost, options.AdditionallyAllowedTenants), m_httpPipeline(HttpPipeline(options, "identity", PackageVersion::ToString(), {}, {})) { m_oidcRequestUrl = _detail::DefaultOptionValues::GetOidcRequestUrl(); - bool isTenantIdValid = IsValidTenantId(tenantId); - if (!isTenantIdValid) - { - IdentityLog::Write( - IdentityLog::Level::Warning, - "Invalid tenant ID provided for " + GetCredentialName() - + ". The tenant ID must be a non-empty string containing only alphanumeric characters, " - "periods, or hyphens. You can locate your tenant ID by following the instructions " - "listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names"); - } - if (clientId.empty()) - { - IdentityLog::Write( - IdentityLog::Level::Warning, "No client ID specified for " + GetCredentialName() + "."); - } if (serviceConnectionId.empty()) { IdentityLog::Write( @@ -101,20 +61,24 @@ AzurePipelinesCredential::AzurePipelinesCredential( + "' needed by " + GetCredentialName() + ". This should be set by Azure Pipelines."); } - if (isTenantIdValid && !clientId.empty() && !serviceConnectionId.empty() - && !systemAccessToken.empty() && !m_oidcRequestUrl.empty()) + if (TenantIdResolver::IsValidTenantId(tenantId) && !clientId.empty() + && !serviceConnectionId.empty() && !systemAccessToken.empty() && !m_oidcRequestUrl.empty()) { - m_tokenCredentialImpl = std::make_unique(options); - m_requestBody - = std::string( - "grant_type=client_credentials" - "&client_assertion_type=" - "urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" // cspell:disable-line - "&client_id=") - + Url::Encode(clientId); + ClientAssertionCredentialOptions clientAssertionCredentialOptions{}; + // Get the options from the base class (including ClientOptions). + static_cast(clientAssertionCredentialOptions) + = options; + clientAssertionCredentialOptions.AuthorityHost = options.AuthorityHost; + clientAssertionCredentialOptions.AdditionallyAllowedTenants + = options.AdditionallyAllowedTenants; - IdentityLog::Write( - IdentityLog::Level::Informational, GetCredentialName() + " was created successfully."); + std::function callback + = [this](Context const& context) { return GetAssertion(context); }; + + // ClientAssertionCredential validates the tenant ID, client ID, and assertion callback and logs + // warning messages otherwise. + m_clientAssertionCredentialImpl = std::make_unique<_detail::ClientAssertionCredentialImpl>( + GetCredentialName(), tenantId, clientId, callback, clientAssertionCredentialOptions); } else { @@ -214,7 +178,7 @@ AccessToken AzurePipelinesCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - if (!m_tokenCredentialImpl) + if (!m_clientAssertionCredentialImpl) { auto const AuthUnavailable = GetCredentialName() + " authentication unavailable. "; @@ -226,41 +190,6 @@ AccessToken AzurePipelinesCredential::GetToken( AuthUnavailable + "Azure Pipelines environment is not set up correctly."); } - auto const tenantId = TenantIdResolver::Resolve( - m_clientCredentialCore.GetTenantId(), - tokenRequestContext, - m_clientCredentialCore.GetAdditionallyAllowedTenants()); - - auto const scopesStr - = m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes); - - // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda - // argument when they are being executed. They are not supposed to keep a reference to lambda - // argument to call it later. Therefore, any capture made here will outlive the possible time - // frame when the lambda might get called. - return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() { - return m_tokenCredentialImpl->GetToken(context, false, [&]() { - auto body = m_requestBody; - if (!scopesStr.empty()) - { - body += "&scope=" + scopesStr; - } - - // Get the request url before calling GetAssertion to validate the authority host scheme. - // This is to avoid making a request to the OIDC endpoint if the authority host scheme is - // invalid. - auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId); - - const std::string assertion = GetAssertion(context); - - body += "&client_assertion=" + Azure::Core::Url::Encode(assertion); - - auto request - = std::make_unique(HttpMethod::Post, requestUrl, body); - - request->HttpRequest.SetHeader("Host", requestUrl.GetHost()); - - return request; - }); - }); + return m_clientAssertionCredentialImpl->GetToken( + GetCredentialName(), tokenRequestContext, context); } diff --git a/sdk/identity/azure-identity/src/client_assertion_credential.cpp b/sdk/identity/azure-identity/src/client_assertion_credential.cpp index 596a06576..822451684 100644 --- a/sdk/identity/azure-identity/src/client_assertion_credential.cpp +++ b/sdk/identity/azure-identity/src/client_assertion_credential.cpp @@ -3,15 +3,16 @@ #include "azure/identity/client_assertion_credential.hpp" +#include "private/client_assertion_credential_impl.hpp" #include "private/identity_log.hpp" #include "private/package_version.hpp" #include "private/tenant_id_resolver.hpp" -#include "private/token_credential_impl.hpp" #include using Azure::Identity::ClientAssertionCredential; using Azure::Identity::ClientAssertionCredentialOptions; +using Azure::Identity::_detail::ClientAssertionCredentialImpl; using Azure::Core::Context; using Azure::Core::Url; @@ -24,44 +25,21 @@ using Azure::Identity::_detail::IdentityLog; using Azure::Identity::_detail::TenantIdResolver; using Azure::Identity::_detail::TokenCredentialImpl; -namespace { -bool IsValidTenantId(std::string const& tenantId) -{ - const std::string allowedChars = ".-"; - if (tenantId.empty()) - { - return false; - } - for (auto const c : tenantId) - { - if (allowedChars.find(c) != std::string::npos) - { - continue; - } - if (!StringExtensions::IsAlphaNumeric(c)) - { - return false; - } - } - return true; -} -} // namespace - -ClientAssertionCredential::ClientAssertionCredential( +ClientAssertionCredentialImpl::ClientAssertionCredentialImpl( + std::string const& credentialName, std::string tenantId, std::string clientId, std::function assertionCallback, ClientAssertionCredentialOptions const& options) - : TokenCredential("ClientAssertionCredential"), - m_assertionCallback(std::move(assertionCallback)), + : m_assertionCallback(std::move(assertionCallback)), m_clientCredentialCore(tenantId, options.AuthorityHost, options.AdditionallyAllowedTenants) { - bool isTenantIdValid = IsValidTenantId(tenantId); + bool isTenantIdValid = TenantIdResolver::IsValidTenantId(tenantId); if (!isTenantIdValid) { IdentityLog::Write( IdentityLog::Level::Warning, - GetCredentialName() + credentialName + ": Invalid tenant ID provided. The tenant ID must be a non-empty string containing " "only alphanumeric characters, periods, or hyphens. You can locate your tenant ID by " "following the instructions listed here: " @@ -69,14 +47,13 @@ ClientAssertionCredential::ClientAssertionCredential( } if (clientId.empty()) { - IdentityLog::Write( - IdentityLog::Level::Warning, GetCredentialName() + ": No client ID specified."); + IdentityLog::Write(IdentityLog::Level::Warning, credentialName + ": No client ID specified."); } if (!m_assertionCallback) { IdentityLog::Write( IdentityLog::Level::Warning, - GetCredentialName() + credentialName + ": The assertionCallback must be a valid function that returns assertions."); } @@ -92,7 +69,7 @@ ClientAssertionCredential::ClientAssertionCredential( + Url::Encode(clientId); IdentityLog::Write( - IdentityLog::Level::Informational, GetCredentialName() + " was created successfully."); + IdentityLog::Level::Informational, credentialName + " was created successfully."); } else { @@ -101,23 +78,22 @@ ClientAssertionCredential::ClientAssertionCredential( // primarily needed for credentials that are part of the DefaultAzureCredential, which this // credential is not intended for. IdentityLog::Write( - IdentityLog::Level::Warning, GetCredentialName() + " was not initialized correctly."); + IdentityLog::Level::Warning, credentialName + " was not initialized correctly."); } } -ClientAssertionCredential::~ClientAssertionCredential() = default; - -AccessToken ClientAssertionCredential::GetToken( +AccessToken ClientAssertionCredentialImpl::GetToken( + std::string const& credentialName, TokenRequestContext const& tokenRequestContext, Context const& context) const { if (!m_tokenCredentialImpl) { - auto const AuthUnavailable = GetCredentialName() + " authentication unavailable. "; + auto const AuthUnavailable = credentialName + " authentication unavailable. "; IdentityLog::Write( IdentityLog::Level::Warning, - AuthUnavailable + "See earlier " + GetCredentialName() + " log messages for details."); + AuthUnavailable + "See earlier " + credentialName + " log messages for details."); throw AuthenticationException(AuthUnavailable); } @@ -160,3 +136,27 @@ AccessToken ClientAssertionCredential::GetToken( }); }); } + +ClientAssertionCredential::ClientAssertionCredential( + std::string tenantId, + std::string clientId, + std::function assertionCallback, + ClientAssertionCredentialOptions const& options) + : TokenCredential("ClientAssertionCredential"), + m_impl(std::make_unique( + GetCredentialName(), + tenantId, + clientId, + assertionCallback, + options)) +{ +} + +ClientAssertionCredential::~ClientAssertionCredential() = default; + +AccessToken ClientAssertionCredential::GetToken( + TokenRequestContext const& tokenRequestContext, + Context const& context) const +{ + return m_impl->GetToken(GetCredentialName(), tokenRequestContext, context); +} diff --git a/sdk/identity/azure-identity/src/private/client_assertion_credential_impl.hpp b/sdk/identity/azure-identity/src/private/client_assertion_credential_impl.hpp new file mode 100644 index 000000000..58144d753 --- /dev/null +++ b/sdk/identity/azure-identity/src/private/client_assertion_credential_impl.hpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * @file + * @brief Client Assertion Credential and options. + */ + +#pragma once + +#include "azure/identity/client_assertion_credential.hpp" +#include "azure/identity/detail/client_credential_core.hpp" +#include "azure/identity/detail/token_cache.hpp" +#include "token_credential_impl.hpp" + +#include + +namespace Azure { namespace Identity { namespace _detail { + class TokenCredentialImpl; + + class ClientAssertionCredentialImpl final { + private: + std::function m_assertionCallback; + _detail::ClientCredentialCore m_clientCredentialCore; + std::unique_ptr m_tokenCredentialImpl; + std::string m_requestBody; + _detail::TokenCache m_tokenCache; + + public: + ClientAssertionCredentialImpl( + std::string const& credentialName, + std::string tenantId, + std::string clientId, + std::function assertionCallback, + ClientAssertionCredentialOptions const& options = {}); + + Core::Credentials::AccessToken GetToken( + std::string const& credentialName, + Core::Credentials::TokenRequestContext const& tokenRequestContext, + Core::Context const& context) const; + }; +}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp b/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp index 4e6d7d461..8a17b211c 100644 --- a/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp +++ b/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp @@ -25,5 +25,7 @@ namespace Azure { namespace Identity { namespace _detail { // ADFS is the Active Directory Federation Service, a tenant ID that is used in Azure Stack. static bool IsAdfs(std::string const& tenantId); + + static bool IsValidTenantId(std::string const& tenantId); }; }}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/tenant_id_resolver.cpp b/sdk/identity/azure-identity/src/tenant_id_resolver.cpp index ff88f56b0..fd89c78bb 100644 --- a/sdk/identity/azure-identity/src/tenant_id_resolver.cpp +++ b/sdk/identity/azure-identity/src/tenant_id_resolver.cpp @@ -56,3 +56,24 @@ bool TenantIdResolver::IsAdfs(std::string const& tenantId) { return StringExtensions::LocaleInvariantCaseInsensitiveEqual(tenantId, "adfs"); } + +bool TenantIdResolver::IsValidTenantId(std::string const& tenantId) +{ + const std::string allowedChars = ".-"; + if (tenantId.empty()) + { + return false; + } + for (auto const c : tenantId) + { + if (allowedChars.find(c) != std::string::npos) + { + continue; + } + if (!StringExtensions::IsAlphaNumeric(c)) + { + return false; + } + } + return true; +} diff --git a/sdk/identity/azure-identity/src/workload_identity_credential.cpp b/sdk/identity/azure-identity/src/workload_identity_credential.cpp index 2fe0cd213..59711966b 100644 --- a/sdk/identity/azure-identity/src/workload_identity_credential.cpp +++ b/sdk/identity/azure-identity/src/workload_identity_credential.cpp @@ -3,9 +3,9 @@ #include "azure/identity/workload_identity_credential.hpp" +#include "private/client_assertion_credential_impl.hpp" #include "private/identity_log.hpp" #include "private/tenant_id_resolver.hpp" -#include "private/token_credential_impl.hpp" #include @@ -23,35 +23,32 @@ using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Http::HttpMethod; using Azure::Identity::_detail::IdentityLog; using Azure::Identity::_detail::TenantIdResolver; -using Azure::Identity::_detail::TokenCredentialImpl; WorkloadIdentityCredential::WorkloadIdentityCredential( WorkloadIdentityCredentialOptions const& options) - : TokenCredential("WorkloadIdentityCredential"), m_clientCredentialCore( - options.TenantId, - options.AuthorityHost, - options.AdditionallyAllowedTenants) + : TokenCredential("WorkloadIdentityCredential") { std::string tenantId = options.TenantId; std::string clientId = options.ClientId; - std::string authorityHost = options.AuthorityHost; m_tokenFilePath = options.TokenFilePath; - if (!tenantId.empty() && !clientId.empty() && !m_tokenFilePath.empty()) + if (TenantIdResolver::IsValidTenantId(tenantId) && !clientId.empty() && !m_tokenFilePath.empty()) { - m_clientCredentialCore = Azure::Identity::_detail::ClientCredentialCore( - tenantId, authorityHost, options.AdditionallyAllowedTenants); - m_tokenCredentialImpl = std::make_unique(options); - m_requestBody - = std::string( - "grant_type=client_credentials" - "&client_assertion_type=" - "urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" // cspell:disable-line - "&client_id=") - + Url::Encode(clientId); + ClientAssertionCredentialOptions clientAssertionCredentialOptions{}; + // Get the options from the base class (including ClientOptions). + static_cast(clientAssertionCredentialOptions) + = options; + clientAssertionCredentialOptions.AuthorityHost = options.AuthorityHost; + clientAssertionCredentialOptions.AdditionallyAllowedTenants + = options.AdditionallyAllowedTenants; - IdentityLog::Write( - IdentityLog::Level::Informational, GetCredentialName() + " was created successfully."); + std::function callback + = [this](Context const& context) { return GetAssertion(context); }; + + // ClientAssertionCredential validates the tenant ID, client ID, and assertion callback and logs + // warning messages otherwise. + m_clientAssertionCredentialImpl = std::make_unique<_detail::ClientAssertionCredentialImpl>( + GetCredentialName(), tenantId, clientId, callback, clientAssertionCredentialOptions); } else { @@ -64,30 +61,26 @@ WorkloadIdentityCredential::WorkloadIdentityCredential( WorkloadIdentityCredential::WorkloadIdentityCredential( Core::Credentials::TokenCredentialOptions const& options) - : TokenCredential("WorkloadIdentityCredential"), - m_clientCredentialCore("", "", std::vector()) + : TokenCredential("WorkloadIdentityCredential") { std::string const tenantId = _detail::DefaultOptionValues::GetTenantId(); std::string const clientId = _detail::DefaultOptionValues::GetClientId(); m_tokenFilePath = _detail::DefaultOptionValues::GetFederatedTokenFile(); - if (!tenantId.empty() && !clientId.empty() && !m_tokenFilePath.empty()) + if (TenantIdResolver::IsValidTenantId(tenantId) && !clientId.empty() && !m_tokenFilePath.empty()) { - std::string const authorityHost = _detail::DefaultOptionValues::GetAuthorityHost(); + ClientAssertionCredentialOptions clientAssertionCredentialOptions{}; + // Get the options from the base class (including ClientOptions). + static_cast(clientAssertionCredentialOptions) + = options; - m_clientCredentialCore = Azure::Identity::_detail::ClientCredentialCore( - tenantId, authorityHost, std::vector()); - m_tokenCredentialImpl = std::make_unique(options); - m_requestBody - = std::string( - "grant_type=client_credentials" - "&client_assertion_type=" - "urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" // cspell:disable-line - "&client_id=") - + Url::Encode(clientId); + std::function callback + = [this](Context const& context) { return GetAssertion(context); }; - IdentityLog::Write( - IdentityLog::Level::Informational, GetCredentialName() + " was created successfully."); + // ClientAssertionCredential validates the tenant ID, client ID, and assertion callback and logs + // warning messages otherwise. + m_clientAssertionCredentialImpl = std::make_unique<_detail::ClientAssertionCredentialImpl>( + GetCredentialName(), tenantId, clientId, callback, clientAssertionCredentialOptions); } else { @@ -100,11 +93,21 @@ WorkloadIdentityCredential::WorkloadIdentityCredential( WorkloadIdentityCredential::~WorkloadIdentityCredential() = default; +std::string WorkloadIdentityCredential::GetAssertion(Context const&) const +{ + // Read the specified file's content, which is expected to be a Kubernetes service account + // token. Kubernetes is responsible for updating the file as service account tokens expire. + std::ifstream azureFederatedTokenFile(m_tokenFilePath); + std::string assertion( + (std::istreambuf_iterator(azureFederatedTokenFile)), std::istreambuf_iterator()); + return assertion; +} + AccessToken WorkloadIdentityCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - if (!m_tokenCredentialImpl) + if (!m_clientAssertionCredentialImpl) { auto const AuthUnavailable = GetCredentialName() + " authentication unavailable. "; @@ -116,43 +119,6 @@ AccessToken WorkloadIdentityCredential::GetToken( AuthUnavailable + "Azure Kubernetes environment is not set up correctly."); } - auto const tenantId = TenantIdResolver::Resolve( - m_clientCredentialCore.GetTenantId(), - tokenRequestContext, - m_clientCredentialCore.GetAdditionallyAllowedTenants()); - - auto const scopesStr - = m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes); - - // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument - // when they are being executed. They are not supposed to keep a reference to lambda argument to - // call it later. Therefore, any capture made here will outlive the possible time frame when the - // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() { - return m_tokenCredentialImpl->GetToken(context, false, [&]() { - auto body = m_requestBody; - if (!scopesStr.empty()) - { - body += "&scope=" + scopesStr; - } - - auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId); - - // Read the specified file's content, which is expected to be a Kubernetes service account - // token. Kubernetes is responsible for updating the file as service account tokens expire. - std::ifstream azureFederatedTokenFile(m_tokenFilePath); - std::string assertion( - (std::istreambuf_iterator(azureFederatedTokenFile)), - std::istreambuf_iterator()); - - body += "&client_assertion=" + Azure::Core::Url::Encode(assertion); - - auto request - = std::make_unique(HttpMethod::Post, requestUrl, body); - - request->HttpRequest.SetHeader("Host", requestUrl.GetHost()); - - return request; - }); - }); + return m_clientAssertionCredentialImpl->GetToken( + GetCredentialName(), tokenRequestContext, context); }