Identity: Add AzureCliCredentialOptions::Subscription (#6415)

* Identity: Add AzureCliCredentialOptions::Subscription

* Clang-format

* Id => ID

* Fix typo

* Apply suggestions from code review

Co-authored-by: Larry Osterman <LarryOsterman@users.noreply.github.com>

* Replace `decltype` with `auto` in test file

* + "If this is the name of a subscription, use its ID instead."

---------

Co-authored-by: Anton Kolesnyk <antkmsft@users.noreply.github.com>
Co-authored-by: Larry Osterman <LarryOsterman@users.noreply.github.com>
This commit is contained in:
Anton Kolesnyk 2025-02-13 15:26:10 -08:00 committed by GitHub
parent 092ac143ed
commit b40045e8e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 181 additions and 28 deletions

View File

@ -4,6 +4,8 @@
### Features Added
- Added `Subscription` to `AzureCliCredentialOptions` which allows the caller to specify an Azure subscription that does not match the current Azure CLI subscription.
### Breaking Changes
### Bugs Fixed

View File

@ -43,6 +43,12 @@ namespace Azure { namespace Identity {
* for any tenant in which the application is installed.
*/
std::vector<std::string> AdditionallyAllowedTenants;
/**
* @brief The name or ID of an Azure subscription. If not empty, it enables acquiring tokens for
* a subscription other than the Azure CLI's current subscription.
*/
std::string Subscription;
};
/**
@ -54,7 +60,12 @@ namespace Azure { namespace Identity {
final
#endif
: public Core::Credentials::TokenCredential {
#if !defined(_azure_TESTING_BUILD)
private:
#else
protected:
#endif
/** @brief The cache for the access token. */
_detail::TokenCache m_tokenCache;
@ -67,17 +78,22 @@ namespace Azure { namespace Identity {
/** @brief The CLI process timeout. */
DateTime::duration m_cliProcessTimeout;
/** @brief Subscription name or ID. */
std::string m_subscription;
private:
explicit AzureCliCredential(
Core::Credentials::TokenCredentialOptions const& options,
std::string tenantId,
DateTime::duration cliProcessTimeout,
std::vector<std::string> additionallyAllowedTenants);
std::vector<std::string> additionallyAllowedTenants,
std::string subscription);
void ThrowIfNotSafeCmdLineInput(
std::string const& input,
std::string const& allowedChars,
std::string const& description) const;
std::string const& description,
std::string const& details) const;
public:
/**
@ -111,7 +127,11 @@ namespace Azure { namespace Identity {
#else
protected:
#endif
virtual std::string GetAzCommand(std::string const& scopes, std::string const& tenantId) const;
virtual std::string GetAzCommand(
std::string const& scopes,
std::string const& tenantId,
std::string const& subscription) const;
virtual int GetLocalTimeToUtcDiffSeconds() const;
};

View File

@ -60,7 +60,8 @@ using Azure::Identity::_detail::TokenCredentialImpl;
void AzureCliCredential::ThrowIfNotSafeCmdLineInput(
std::string const& input,
std::string const& allowedChars,
std::string const& description) const
std::string const& description,
std::string const& details = {}) const
{
for (auto const c : input)
{
@ -71,8 +72,8 @@ void AzureCliCredential::ThrowIfNotSafeCmdLineInput(
if (!StringExtensions::IsAlphaNumeric(c))
{
throw AuthenticationException(
GetCredentialName() + ": Unsafe command line input found in " + description + ": "
+ input);
GetCredentialName() + ": Unsafe command line input found in " + description + ": " + input
+ details);
}
}
}
@ -80,10 +81,12 @@ AzureCliCredential::AzureCliCredential(
Core::Credentials::TokenCredentialOptions const& options,
std::string tenantId,
DateTime::duration cliProcessTimeout,
std::vector<std::string> additionallyAllowedTenants)
std::vector<std::string> additionallyAllowedTenants,
std::string subscription)
: TokenCredential("AzureCliCredential"),
m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)),
m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout))
m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout)),
m_subscription(std::move(subscription))
{
static_cast<void>(options);
@ -99,7 +102,8 @@ AzureCliCredential::AzureCliCredential(AzureCliCredentialOptions const& options)
options,
options.TenantId,
options.CliProcessTimeout,
options.AdditionallyAllowedTenants)
options.AdditionallyAllowedTenants,
options.Subscription)
{
}
@ -108,17 +112,26 @@ AzureCliCredential::AzureCliCredential(const Core::Credentials::TokenCredentialO
options,
AzureCliCredentialOptions{}.TenantId,
AzureCliCredentialOptions{}.CliProcessTimeout,
AzureCliCredentialOptions{}.AdditionallyAllowedTenants)
AzureCliCredentialOptions{}.AdditionallyAllowedTenants,
AzureCliCredentialOptions{}.Subscription)
{
}
std::string AzureCliCredential::GetAzCommand(std::string const& scopes, std::string const& tenantId)
const
std::string AzureCliCredential::GetAzCommand(
std::string const& scopes,
std::string const& tenantId,
std::string const& subscription) const
{
// The OAuth 2.0 RFC (https://datatracker.ietf.org/doc/html/rfc6749#section-3.3) allows space as
// well for a list of scopes, but that isn't currently required.
ThrowIfNotSafeCmdLineInput(scopes, ".-:/_", "Scopes");
ThrowIfNotSafeCmdLineInput(tenantId, ".-", "TenantID");
ThrowIfNotSafeCmdLineInput(
subscription,
".-_ ",
"Subscription",
". If this is the name of a subscription, use its ID instead.");
std::string command = "az account get-access-token --output json --scope \"" + scopes + "\"";
if (!tenantId.empty())
@ -126,6 +139,11 @@ std::string AzureCliCredential::GetAzCommand(std::string const& scopes, std::str
command += " --tenant \"" + tenantId + "\"";
}
if (!subscription.empty())
{
command += " --subscription \"" + subscription + "\"";
}
return command;
}
@ -164,7 +182,7 @@ AccessToken AzureCliCredential::GetToken(
auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false);
auto const tenantId
= TenantIdResolver::Resolve(m_tenantId, tokenRequestContext, m_additionallyAllowedTenants);
auto const command = GetAzCommand(scopes, tenantId);
auto const command = GetAzCommand(scopes, tenantId, m_subscription);
// TokenCache::GetToken() can only use the lambda argument when they are being executed. They
// are not supposed to keep a reference to lambda argument to call it later. Therefore, any

View File

@ -53,10 +53,14 @@ private:
std::string m_command;
int m_localTimeToUtcDiffSeconds = 0;
std::string GetAzCommand(std::string const& resource, std::string const& tenantId) const override
std::string GetAzCommand(
std::string const& resource,
std::string const& tenantId,
std::string const& subscription) const override
{
static_cast<void>(resource);
static_cast<void>(tenantId);
static_cast<void>(subscription);
return m_command;
}
@ -76,13 +80,17 @@ public:
{
}
std::string GetOriginalAzCommand(std::string const& resource, std::string const& tenantId) const
std::string GetOriginalAzCommand(
std::string const& resource,
std::string const& tenantId,
std::string const& subscription) const
{
return AzureCliCredential::GetAzCommand(resource, tenantId);
return AzureCliCredential::GetAzCommand(resource, tenantId, subscription);
}
decltype(m_tenantId) const& GetTenantId() const { return m_tenantId; }
decltype(m_cliProcessTimeout) const& GetCliProcessTimeout() const { return m_cliProcessTimeout; }
auto const& GetTenantId() const { return m_tenantId; }
auto const& GetSubscription() const { return m_subscription; }
auto const& GetCliProcessTimeout() const { return m_cliProcessTimeout; }
void SetLocalTimeToUtcDiffSeconds(int diff) { m_localTimeToUtcDiffSeconds = diff; }
};
@ -329,6 +337,7 @@ TEST(AzureCliCredential, Defaults)
AzureCliTestCredential azCliCred({});
EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId);
EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout);
EXPECT_EQ(azCliCred.GetSubscription(), DefaultOptions.Subscription);
}
{
@ -336,6 +345,7 @@ TEST(AzureCliCredential, Defaults)
AzureCliTestCredential azCliCred({}, options);
EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId);
EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout);
EXPECT_EQ(azCliCred.GetSubscription(), DefaultOptions.Subscription);
}
}
@ -343,11 +353,13 @@ TEST(AzureCliCredential, Defaults)
AzureCliCredentialOptions options;
options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF";
options.CliProcessTimeout = std::chrono::seconds(12345);
options.Subscription = "Azure Sub_scrip-t1.0n";
AzureCliTestCredential azCliCred({}, options);
EXPECT_EQ(azCliCred.GetTenantId(), "01234567-89AB-CDEF-0123-456789ABCDEF");
EXPECT_EQ(azCliCred.GetCliProcessTimeout(), std::chrono::seconds(12345));
EXPECT_EQ(azCliCred.GetSubscription(), "Azure Sub_scrip-t1.0n");
}
}
@ -356,10 +368,18 @@ TEST(AzureCliCredential, CmdLine)
AzureCliTestCredential azCliCred({});
auto const cmdLineWithoutTenant
= azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {});
= azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {}, {});
auto const cmdLineWithTenant = azCliCred.GetOriginalAzCommand(
"https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF");
"https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF", {});
auto const cmdLineWithoutTenantAndWithSubscription = azCliCred.GetOriginalAzCommand(
"https://storage.azure.com/.default", {}, "Azure Sub_scrip-t1.0n");
auto const cmdLineWithTenantAndWithSubscription = azCliCred.GetOriginalAzCommand(
"https://storage.azure.com/.default",
"01234567-89AB-CDEF-0123-456789ABCDEF",
"Azure Sub_scrip-t1.0n");
EXPECT_EQ(
cmdLineWithoutTenant,
@ -369,16 +389,27 @@ TEST(AzureCliCredential, CmdLine)
cmdLineWithTenant,
"az account get-access-token --output json --scope \"https://storage.azure.com/.default\""
" --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\"");
EXPECT_EQ(
cmdLineWithoutTenantAndWithSubscription,
"az account get-access-token --output json --scope \"https://storage.azure.com/.default\""
" --subscription \"Azure Sub_scrip-t1.0n\"");
EXPECT_EQ(
cmdLineWithTenantAndWithSubscription,
"az account get-access-token --output json --scope \"https://storage.azure.com/.default\""
" --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\""
" --subscription \"Azure Sub_scrip-t1.0n\"");
}
TEST(AzureCliCredential, UnsafeChars)
{
std::string const Exploit = std::string("\" | echo OWNED | ") + InfiniteCommand + " | echo \"";
std::string const Unsafe = std::string("\" | echo UNSAFE | ") + InfiniteCommand + " | echo \"";
{
AzureCliCredentialOptions options;
options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF";
options.TenantId += Exploit;
options.TenantId += Unsafe;
AzureCliCredential azCliCred(options);
TokenRequestContext trc;
@ -392,10 +423,21 @@ TEST(AzureCliCredential, UnsafeChars)
AzureCliCredential azCliCred(options);
TokenRequestContext trc;
trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Exploit);
trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Unsafe);
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
}
{
AzureCliCredentialOptions options;
options.Subscription = "Azure Sub_scrip-t1.0n";
options.Subscription += Unsafe;
AzureCliCredential azCliCred(options);
TokenRequestContext trc;
trc.Scopes.push_back(std::string("https://storage.azure.com/.default"));
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
}
}
class ParameterizedTestForDisallowedChars : public ::testing::TestWithParam<std::string> {
@ -418,14 +460,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId)
trc.Scopes.push_back(std::string("https://storage.azure.com/.default"));
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what();
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in TenantID: " + options.TenantId);
}
EXPECT_TRUE(exceptionThrown);
}
// Tenant ID test via TokenRequestContext, using a wildcard for AdditionallyAllowedTenants.
@ -440,14 +488,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId)
trc.TenantId = InvalidValue;
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what();
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in TenantID: " + trc.TenantId);
}
EXPECT_TRUE(exceptionThrown);
}
// Tenant ID test via TokenRequestContext, using a specific AdditionallyAllowedTenants value.
@ -461,14 +515,20 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId)
trc.TenantId = InvalidValue;
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what();
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in TenantID: " + trc.TenantId);
}
EXPECT_TRUE(exceptionThrown);
}
// Scopes test via TokenRequestContext.
@ -481,14 +541,49 @@ TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId)
trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + InvalidValue);
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what();
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in Scopes: " + trc.Scopes.at(0));
}
EXPECT_TRUE(exceptionThrown);
}
if (InvalidValue != " ")
{
AzureCliCredentialOptions options;
options.Subscription = "Azure Sub_scrip-t1.0n";
options.Subscription += InvalidValue;
AzureCliCredential azCliCred(options);
TokenRequestContext trc;
trc.Scopes.push_back(std::string("https://storage.azure.com/.default"));
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in Subscription: "
+ options.Subscription
+ ". If this is the name of a subscription, use its ID instead.");
}
EXPECT_TRUE(exceptionThrown);
}
}
@ -516,14 +611,20 @@ TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId)
trc.Scopes.push_back(std::string("https://storage.azure.com/.default"));
EXPECT_THROW(static_cast<void>(azCliCred.GetToken(trc, {})), AuthenticationException);
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what();
exceptionThrown = true;
EXPECT_EQ(
e.what(),
"AzureCliCredential: Unsafe command line input found in TenantID: " + options.TenantId);
}
EXPECT_TRUE(exceptionThrown);
}
{
@ -536,14 +637,18 @@ TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId)
std::string("https://storage.azure.com/.default") + ValidScopeButNotTenantId);
// We expect the GetToken to fail, but not because of the unsafe chars.
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
exceptionThrown = true;
EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what();
}
EXPECT_TRUE(exceptionThrown);
}
}
@ -571,14 +676,18 @@ TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId)
trc.Scopes.push_back(std::string("https://storage.azure.com/.default"));
// We expect the GetToken to fail, but not because of the unsafe chars.
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
exceptionThrown = true;
EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what();
}
EXPECT_TRUE(exceptionThrown);
}
{
@ -590,14 +699,18 @@ TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId)
trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + ValidChars);
// We expect the GetToken to fail, but not because of the unsafe chars.
auto exceptionThrown = false;
try
{
auto const token = azCliCred.GetToken(trc, {});
}
catch (AuthenticationException const& e)
{
exceptionThrown = true;
EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what();
}
EXPECT_TRUE(exceptionThrown);
}
}