From b40045e8e9843ecaf0c7ad69e45d2d1078dad831 Mon Sep 17 00:00:00 2001 From: Anton Kolesnyk <41349689+antkmsft@users.noreply.github.com> Date: Thu, 13 Feb 2025 15:26:10 -0800 Subject: [PATCH] Identity: Add AzureCliCredentialOptions::Subscription (#6415) * Identity: Add AzureCliCredentialOptions::Subscription * Clang-format * Id => ID * Fix typo * Apply suggestions from code review Co-authored-by: Larry Osterman * Replace `decltype` with `auto` in test file * + "If this is the name of a subscription, use its ID instead." --------- Co-authored-by: Anton Kolesnyk Co-authored-by: Larry Osterman --- sdk/identity/azure-identity/CHANGELOG.md | 2 + .../azure/identity/azure_cli_credential.hpp | 26 +++- .../src/azure_cli_credential.cpp | 38 +++-- .../test/ut/azure_cli_credential_test.cpp | 143 ++++++++++++++++-- 4 files changed, 181 insertions(+), 28 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 5a6ed7c25..3ce370ec8 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added `Subscription` to `AzureCliCredentialOptions` which allows the caller to specify an Azure subscription that does not match the current Azure CLI subscription. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp index 1b1beb3a4..245ca006d 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -43,6 +43,12 @@ namespace Azure { namespace Identity { * for any tenant in which the application is installed. */ std::vector AdditionallyAllowedTenants; + + /** + * @brief The name or ID of an Azure subscription. If not empty, it enables acquiring tokens for + * a subscription other than the Azure CLI's current subscription. + */ + std::string Subscription; }; /** @@ -54,7 +60,12 @@ namespace Azure { namespace Identity { final #endif : public Core::Credentials::TokenCredential { +#if !defined(_azure_TESTING_BUILD) + private: +#else protected: +#endif + /** @brief The cache for the access token. */ _detail::TokenCache m_tokenCache; @@ -67,17 +78,22 @@ namespace Azure { namespace Identity { /** @brief The CLI process timeout. */ DateTime::duration m_cliProcessTimeout; + /** @brief Subscription name or ID. */ + std::string m_subscription; + private: explicit AzureCliCredential( Core::Credentials::TokenCredentialOptions const& options, std::string tenantId, DateTime::duration cliProcessTimeout, - std::vector additionallyAllowedTenants); + std::vector additionallyAllowedTenants, + std::string subscription); void ThrowIfNotSafeCmdLineInput( std::string const& input, std::string const& allowedChars, - std::string const& description) const; + std::string const& description, + std::string const& details) const; public: /** @@ -111,7 +127,11 @@ namespace Azure { namespace Identity { #else protected: #endif - virtual std::string GetAzCommand(std::string const& scopes, std::string const& tenantId) const; + virtual std::string GetAzCommand( + std::string const& scopes, + std::string const& tenantId, + std::string const& subscription) const; + virtual int GetLocalTimeToUtcDiffSeconds() const; }; diff --git a/sdk/identity/azure-identity/src/azure_cli_credential.cpp b/sdk/identity/azure-identity/src/azure_cli_credential.cpp index 9d2944fc5..1c6744d17 100644 --- a/sdk/identity/azure-identity/src/azure_cli_credential.cpp +++ b/sdk/identity/azure-identity/src/azure_cli_credential.cpp @@ -60,7 +60,8 @@ using Azure::Identity::_detail::TokenCredentialImpl; void AzureCliCredential::ThrowIfNotSafeCmdLineInput( std::string const& input, std::string const& allowedChars, - std::string const& description) const + std::string const& description, + std::string const& details = {}) const { for (auto const c : input) { @@ -71,8 +72,8 @@ void AzureCliCredential::ThrowIfNotSafeCmdLineInput( if (!StringExtensions::IsAlphaNumeric(c)) { throw AuthenticationException( - GetCredentialName() + ": Unsafe command line input found in " + description + ": " - + input); + GetCredentialName() + ": Unsafe command line input found in " + description + ": " + input + + details); } } } @@ -80,10 +81,12 @@ AzureCliCredential::AzureCliCredential( Core::Credentials::TokenCredentialOptions const& options, std::string tenantId, DateTime::duration cliProcessTimeout, - std::vector additionallyAllowedTenants) + std::vector additionallyAllowedTenants, + std::string subscription) : TokenCredential("AzureCliCredential"), m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)), - m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout)) + m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout)), + m_subscription(std::move(subscription)) { static_cast(options); @@ -99,7 +102,8 @@ AzureCliCredential::AzureCliCredential(AzureCliCredentialOptions const& options) options, options.TenantId, options.CliProcessTimeout, - options.AdditionallyAllowedTenants) + options.AdditionallyAllowedTenants, + options.Subscription) { } @@ -108,17 +112,26 @@ AzureCliCredential::AzureCliCredential(const Core::Credentials::TokenCredentialO options, AzureCliCredentialOptions{}.TenantId, AzureCliCredentialOptions{}.CliProcessTimeout, - AzureCliCredentialOptions{}.AdditionallyAllowedTenants) + AzureCliCredentialOptions{}.AdditionallyAllowedTenants, + AzureCliCredentialOptions{}.Subscription) { } -std::string AzureCliCredential::GetAzCommand(std::string const& scopes, std::string const& tenantId) - const +std::string AzureCliCredential::GetAzCommand( + std::string const& scopes, + std::string const& tenantId, + std::string const& subscription) const { // The OAuth 2.0 RFC (https://datatracker.ietf.org/doc/html/rfc6749#section-3.3) allows space as // well for a list of scopes, but that isn't currently required. ThrowIfNotSafeCmdLineInput(scopes, ".-:/_", "Scopes"); ThrowIfNotSafeCmdLineInput(tenantId, ".-", "TenantID"); + ThrowIfNotSafeCmdLineInput( + subscription, + ".-_ ", + "Subscription", + ". If this is the name of a subscription, use its ID instead."); + std::string command = "az account get-access-token --output json --scope \"" + scopes + "\""; if (!tenantId.empty()) @@ -126,6 +139,11 @@ std::string AzureCliCredential::GetAzCommand(std::string const& scopes, std::str command += " --tenant \"" + tenantId + "\""; } + if (!subscription.empty()) + { + command += " --subscription \"" + subscription + "\""; + } + return command; } @@ -164,7 +182,7 @@ AccessToken AzureCliCredential::GetToken( auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false); auto const tenantId = TenantIdResolver::Resolve(m_tenantId, tokenRequestContext, m_additionallyAllowedTenants); - auto const command = GetAzCommand(scopes, tenantId); + auto const command = GetAzCommand(scopes, tenantId, m_subscription); // TokenCache::GetToken() can only use the lambda argument when they are being executed. They // are not supposed to keep a reference to lambda argument to call it later. Therefore, any diff --git a/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp index e53b9ac40..a66a6cbae 100644 --- a/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp @@ -53,10 +53,14 @@ private: std::string m_command; int m_localTimeToUtcDiffSeconds = 0; - std::string GetAzCommand(std::string const& resource, std::string const& tenantId) const override + std::string GetAzCommand( + std::string const& resource, + std::string const& tenantId, + std::string const& subscription) const override { static_cast(resource); static_cast(tenantId); + static_cast(subscription); return m_command; } @@ -76,13 +80,17 @@ public: { } - std::string GetOriginalAzCommand(std::string const& resource, std::string const& tenantId) const + std::string GetOriginalAzCommand( + std::string const& resource, + std::string const& tenantId, + std::string const& subscription) const { - return AzureCliCredential::GetAzCommand(resource, tenantId); + return AzureCliCredential::GetAzCommand(resource, tenantId, subscription); } - decltype(m_tenantId) const& GetTenantId() const { return m_tenantId; } - decltype(m_cliProcessTimeout) const& GetCliProcessTimeout() const { return m_cliProcessTimeout; } + auto const& GetTenantId() const { return m_tenantId; } + auto const& GetSubscription() const { return m_subscription; } + auto const& GetCliProcessTimeout() const { return m_cliProcessTimeout; } void SetLocalTimeToUtcDiffSeconds(int diff) { m_localTimeToUtcDiffSeconds = diff; } }; @@ -329,6 +337,7 @@ TEST(AzureCliCredential, Defaults) AzureCliTestCredential azCliCred({}); EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + EXPECT_EQ(azCliCred.GetSubscription(), DefaultOptions.Subscription); } { @@ -336,6 +345,7 @@ TEST(AzureCliCredential, Defaults) AzureCliTestCredential azCliCred({}, options); EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + EXPECT_EQ(azCliCred.GetSubscription(), DefaultOptions.Subscription); } } @@ -343,11 +353,13 @@ TEST(AzureCliCredential, Defaults) AzureCliCredentialOptions options; options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; options.CliProcessTimeout = std::chrono::seconds(12345); + options.Subscription = "Azure Sub_scrip-t1.0n"; AzureCliTestCredential azCliCred({}, options); EXPECT_EQ(azCliCred.GetTenantId(), "01234567-89AB-CDEF-0123-456789ABCDEF"); EXPECT_EQ(azCliCred.GetCliProcessTimeout(), std::chrono::seconds(12345)); + EXPECT_EQ(azCliCred.GetSubscription(), "Azure Sub_scrip-t1.0n"); } } @@ -356,10 +368,18 @@ TEST(AzureCliCredential, CmdLine) AzureCliTestCredential azCliCred({}); auto const cmdLineWithoutTenant - = azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {}); + = azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {}, {}); auto const cmdLineWithTenant = azCliCred.GetOriginalAzCommand( - "https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF"); + "https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF", {}); + + auto const cmdLineWithoutTenantAndWithSubscription = azCliCred.GetOriginalAzCommand( + "https://storage.azure.com/.default", {}, "Azure Sub_scrip-t1.0n"); + + auto const cmdLineWithTenantAndWithSubscription = azCliCred.GetOriginalAzCommand( + "https://storage.azure.com/.default", + "01234567-89AB-CDEF-0123-456789ABCDEF", + "Azure Sub_scrip-t1.0n"); EXPECT_EQ( cmdLineWithoutTenant, @@ -369,16 +389,27 @@ TEST(AzureCliCredential, CmdLine) cmdLineWithTenant, "az account get-access-token --output json --scope \"https://storage.azure.com/.default\"" " --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\""); + + EXPECT_EQ( + cmdLineWithoutTenantAndWithSubscription, + "az account get-access-token --output json --scope \"https://storage.azure.com/.default\"" + " --subscription \"Azure Sub_scrip-t1.0n\""); + + EXPECT_EQ( + cmdLineWithTenantAndWithSubscription, + "az account get-access-token --output json --scope \"https://storage.azure.com/.default\"" + " --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\"" + " --subscription \"Azure Sub_scrip-t1.0n\""); } TEST(AzureCliCredential, UnsafeChars) { - std::string const Exploit = std::string("\" | echo OWNED | ") + InfiniteCommand + " | echo \""; + std::string const Unsafe = std::string("\" | echo UNSAFE | ") + InfiniteCommand + " | echo \""; { AzureCliCredentialOptions options; options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.TenantId += Exploit; + options.TenantId += Unsafe; AzureCliCredential azCliCred(options); TokenRequestContext trc; @@ -392,10 +423,21 @@ TEST(AzureCliCredential, UnsafeChars) AzureCliCredential azCliCred(options); TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Exploit); + trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Unsafe); EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); } + + { + AzureCliCredentialOptions options; + options.Subscription = "Azure Sub_scrip-t1.0n"; + options.Subscription += Unsafe; + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + } } class ParameterizedTestForDisallowedChars : public ::testing::TestWithParam { @@ -418,14 +460,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in TenantID: " + options.TenantId); } + + EXPECT_TRUE(exceptionThrown); } // Tenant ID test via TokenRequestContext, using a wildcard for AdditionallyAllowedTenants. @@ -440,14 +488,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) trc.TenantId = InvalidValue; EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in TenantID: " + trc.TenantId); } + + EXPECT_TRUE(exceptionThrown); } // Tenant ID test via TokenRequestContext, using a specific AdditionallyAllowedTenants value. @@ -461,14 +515,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) trc.TenantId = InvalidValue; EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in TenantID: " + trc.TenantId); } + + EXPECT_TRUE(exceptionThrown); } // Scopes test via TokenRequestContext. @@ -481,14 +541,49 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + InvalidValue); EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in Scopes: " + trc.Scopes.at(0)); } + + EXPECT_TRUE(exceptionThrown); + } + + if (InvalidValue != " ") + { + AzureCliCredentialOptions options; + options.Subscription = "Azure Sub_scrip-t1.0n"; + options.Subscription += InvalidValue; + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + + auto exceptionThrown = false; + try + { + auto const token = azCliCred.GetToken(trc, {}); + } + catch (AuthenticationException const& e) + { + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in Subscription: " + + options.Subscription + + ". If this is the name of a subscription, use its ID instead."); + } + + EXPECT_TRUE(exceptionThrown); } } @@ -516,14 +611,20 @@ TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId) trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + exceptionThrown = true; + EXPECT_EQ( + e.what(), + "AzureCliCredential: Unsafe command line input found in TenantID: " + options.TenantId); } + + EXPECT_TRUE(exceptionThrown); } { @@ -536,14 +637,18 @@ TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId) std::string("https://storage.azure.com/.default") + ValidScopeButNotTenantId); // We expect the GetToken to fail, but not because of the unsafe chars. + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { + exceptionThrown = true; EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); } + + EXPECT_TRUE(exceptionThrown); } } @@ -571,14 +676,18 @@ TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId) trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); // We expect the GetToken to fail, but not because of the unsafe chars. + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { + exceptionThrown = true; EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); } + + EXPECT_TRUE(exceptionThrown); } { @@ -590,14 +699,18 @@ TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId) trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + ValidChars); // We expect the GetToken to fail, but not because of the unsafe chars. + auto exceptionThrown = false; try { auto const token = azCliCred.GetToken(trc, {}); } catch (AuthenticationException const& e) { + exceptionThrown = true; EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); } + + EXPECT_TRUE(exceptionThrown); } }