diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 292f81a15..a5887d8b5 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -7,6 +7,7 @@ - Added the ability to ignore invalid certificate common name for TLS connections in WinHTTP transport. - Added `DisableTlsCertificateValidation` in `TransportOptions`. - Added `TokenCredential::GetCredentialName()` to be utilized in diagnostic messages. If you have any custom implementations of `TokenCredential`, it is recommended to pass the name of your credential to `TokenCredential` constructor. The old parameterless constructor is deprecated. +- Added support for challenge-based and multi-tenant authentication. ### Breaking Changes diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index 117c1a42c..154e145ce 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -83,6 +83,7 @@ set( inc/azure/core/http/policies/policy.hpp inc/azure/core/http/raw_response.hpp inc/azure/core/http/transport.hpp + inc/azure/core/internal/credentials/authorization_challenge_parser.hpp inc/azure/core/internal/client_options.hpp inc/azure/core/internal/contract.hpp inc/azure/core/internal/cryptography/sha_hash.hpp @@ -121,6 +122,7 @@ set( src/azure_assert.cpp src/base64.cpp src/context.cpp + src/credentials/authorization_challenge_parser.cpp src/cryptography/md5.cpp src/cryptography/sha_hash.cpp src/datetime.cpp diff --git a/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp b/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp index 710c192b9..d48c9fbbf 100644 --- a/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp +++ b/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp @@ -54,6 +54,12 @@ namespace Azure { namespace Core { namespace Credentials { * */ DateTime::duration MinimumExpiration = std::chrono::minutes(2); + + /** + * @brief Tenant ID. + * + */ + std::string TenantId; }; /** diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index f23c6ef82..6ce723c32 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -543,16 +543,14 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { * @brief Bearer Token authentication policy. * */ - class BearerTokenAuthenticationPolicy final : public HttpPolicy { + class BearerTokenAuthenticationPolicy : public HttpPolicy { private: std::shared_ptr const m_credential; Credentials::TokenRequestContext m_tokenRequestContext; mutable Credentials::AccessToken m_accessToken; mutable std::mutex m_accessTokenMutex; - - BearerTokenAuthenticationPolicy(BearerTokenAuthenticationPolicy const&) = delete; - void operator=(BearerTokenAuthenticationPolicy const&) = delete; + mutable Credentials::TokenRequestContext m_accessTokenContext; public: /** @@ -571,14 +569,36 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { std::unique_ptr Clone() const override { - return std::make_unique( - m_credential, m_tokenRequestContext); + return std::unique_ptr(new BearerTokenAuthenticationPolicy(*this)); } std::unique_ptr Send( Request& request, NextHttpPolicy nextPolicy, Context const& context) const override; + + protected: + BearerTokenAuthenticationPolicy(BearerTokenAuthenticationPolicy const& other) + : BearerTokenAuthenticationPolicy(other.m_credential, other.m_tokenRequestContext) + { + } + + void operator=(BearerTokenAuthenticationPolicy const&) = delete; + + virtual std::unique_ptr AuthorizeAndSendRequest( + Request& request, + NextHttpPolicy& nextPolicy, + Context const& context) const; + + virtual bool AuthorizeRequestOnChallenge( + std::string const& challenge, + Request& request, + Context const& context) const; + + void AuthenticateAndAuthorizeRequest( + Request& request, + Credentials::TokenRequestContext const& tokenRequestContext, + Context const& context) const; }; /** diff --git a/sdk/core/azure-core/inc/azure/core/internal/credentials/authorization_challenge_parser.hpp b/sdk/core/azure-core/inc/azure/core/internal/credentials/authorization_challenge_parser.hpp new file mode 100644 index 000000000..299497531 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/internal/credentials/authorization_challenge_parser.hpp @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Parser for challenge-based authentication policy. + */ + +#pragma once + +#include "azure/core/http/raw_response.hpp" + +#include + +namespace Azure { namespace Core { namespace Credentials { + namespace _detail { + class AuthorizationChallengeHelper final { + public: + static std::string const& GetChallenge(Http::RawResponse const& response); + }; + } // namespace _detail + + namespace _internal { + class AuthorizationChallengeParser final { + private: + AuthorizationChallengeParser() = delete; + ~AuthorizationChallengeParser() = delete; + + public: + /** + * @brief Gets challenge parameter from authentication challenge. + * + * @param challenge Authentication challenge. + * @param challengeScheme The challenge scheme containing the \p challengeParameter. + * @param challengeParameter The parameter key name containing the value to return. + * + * @return Challenge parameter value. + */ + static std::string GetChallengeParameter( + std::string const& challenge, + std::string const& challengeScheme, + std::string const& challengeParameter); + }; + } // namespace _internal +}}} // namespace Azure::Core::Credentials diff --git a/sdk/core/azure-core/src/credentials/authorization_challenge_parser.cpp b/sdk/core/azure-core/src/credentials/authorization_challenge_parser.cpp new file mode 100644 index 000000000..f77be1e3b --- /dev/null +++ b/sdk/core/azure-core/src/credentials/authorization_challenge_parser.cpp @@ -0,0 +1,274 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/internal/credentials/authorization_challenge_parser.hpp" + +#include "azure/core/internal/strings.hpp" + +#include + +using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper; +using Azure::Core::Credentials::_internal::AuthorizationChallengeParser; + +using Azure::Core::_internal::StringExtensions; +using Azure::Core::Http::HttpStatusCode; +using Azure::Core::Http::RawResponse; + +// The parser is implemented to closely mimic the logic in .NET SDK: +// https://github.com/Azure/azure-sdk-for-net/blob/b7efe81fe69e020df74853b0f501ca6ccd5c94a1/sdk/core/Azure.Core/src/Shared/AuthorizationChallengeParser.cs +namespace { +class StringSpan final { + std::string const* m_stringPtr; + int m_startPos; + int m_endPosExclusive; + +public: + StringSpan(std::string const* stringPtr = nullptr); + StringSpan(StringSpan const&) = default; + StringSpan& operator=(StringSpan const&) = default; + + int Length() const; + std::string ToString() const; + StringSpan Slice(int start) const; + StringSpan Slice(int start, int length) const; + StringSpan TrimStart(std::set const& chars) const; + StringSpan Trim(std::set const& chars) const; + int IndexOfAny(std::set const& chars) const; + bool CaseInsensitiveEquals(StringSpan const& other) const; +}; + +bool TryGetNextChallenge(StringSpan& headerValue, StringSpan& challengeKey); +bool TryGetNextParameter(StringSpan& headerValue, StringSpan& paramKey, StringSpan& paramValue); + +std::string const EmptyString; +} // namespace + +std::string const& AuthorizationChallengeHelper::GetChallenge(Http::RawResponse const& response) +{ + // See RFC7235 (https://www.rfc-editor.org/rfc/rfc7235#section-4.1) + if (response.GetStatusCode() == HttpStatusCode::Unauthorized) + { + auto const& headers = response.GetHeaders(); + auto const wwwAuthHeader = headers.find("WWW-Authenticate"); + if (wwwAuthHeader != headers.end()) + { + return wwwAuthHeader->second; + } + } + + return EmptyString; +} + +std::string AuthorizationChallengeParser::GetChallengeParameter( + std::string const& challenge, + std::string const& challengeScheme, + std::string const& challengeParameter) +{ + StringSpan bearer = &challengeScheme; + StringSpan claims = &challengeParameter; + StringSpan headerSpan = &challenge; + + // Iterate through each challenge value. + StringSpan challengeKey; + while (TryGetNextChallenge(headerSpan, challengeKey)) + { + // Enumerate each key=value parameter until we find the 'claims' key on the 'Bearer' + // challenge. + StringSpan key; + StringSpan value; + while (TryGetNextParameter(headerSpan, key, value)) + { + if (challengeKey.CaseInsensitiveEquals(bearer) && key.CaseInsensitiveEquals(claims)) + { + return value.ToString(); + } + } + } + + return {}; +} + +namespace { +std::set const Space = {' '}; +std::set const SpaceOrComma = {' ', ','}; +std::set const Separator = {'='}; +std::set const Quote = {'\"'}; + +bool TryGetNextChallenge(StringSpan& headerValue, StringSpan& challengeKey) +{ + challengeKey = StringSpan(); + + headerValue = headerValue.TrimStart(Space); + int endOfChallengeKey = headerValue.IndexOfAny(Space); + + if (endOfChallengeKey < 0) + { + return false; + } + + challengeKey = headerValue.Slice(0, endOfChallengeKey); + + // Slice the challenge key from the headerValue + headerValue = headerValue.Slice(endOfChallengeKey + 1); + + return true; +} + +bool TryGetNextParameter(StringSpan& headerValue, StringSpan& paramKey, StringSpan& paramValue) +{ + paramKey = StringSpan(); + paramValue = StringSpan(); + + // Trim any separator prefixes. + headerValue = headerValue.TrimStart(SpaceOrComma); + + int nextSpace = headerValue.IndexOfAny(Space); + int nextSeparator = headerValue.IndexOfAny(Separator); + + if (nextSpace < nextSeparator && nextSpace != -1) + { + // we encountered another challenge value. + return false; + } + + if (nextSeparator < 0) + { + return false; + } + + // Get the paramKey. + paramKey = headerValue.Slice(0, nextSeparator).Trim(Space); + + // Slice to remove the 'paramKey=' from the parameters. + headerValue = headerValue.Slice(nextSeparator + 1); + + // The start of paramValue will usually be a quoted string. Find the first quote. + int quoteIndex = headerValue.IndexOfAny(Quote); + + // Get the paramValue, which is delimited by the trailing quote. + headerValue = headerValue.Slice(quoteIndex + 1); + if (quoteIndex >= 0) + { + // The values are quote wrapped + paramValue = headerValue.Slice(0, headerValue.IndexOfAny(Quote)); + } + else + { + // the values are not quote wrapped (storage is one example of this) + // either find the next space indicating the delimiter to the next value, or go to the end + // since this is the last value. + int trailingDelimiterIndex = headerValue.IndexOfAny(SpaceOrComma); + if (trailingDelimiterIndex >= 0) + { + paramValue = headerValue.Slice(0, trailingDelimiterIndex); + } + else + { + paramValue = headerValue; + } + } + + // Slice to remove the '"paramValue"' from the parameters. + if (!headerValue.CaseInsensitiveEquals(paramValue)) + headerValue = headerValue.Slice(paramValue.Length() + 1); + + return true; +} + +StringSpan::StringSpan(std::string const* stringPtr) + : m_stringPtr(stringPtr), m_startPos(0), + m_endPosExclusive(stringPtr ? static_cast(stringPtr->size()) : 0) +{ +} + +int StringSpan::Length() const { return m_endPosExclusive - m_startPos; } + +std::string StringSpan::ToString() const +{ + return m_stringPtr + ? m_stringPtr->substr(static_cast(m_startPos), static_cast(Length())) + : std::string{}; +} + +StringSpan StringSpan::Slice(int start) const +{ + StringSpan result = *this; + result.m_startPos += start; + return result; +} + +StringSpan StringSpan::Slice(int start, int length) const +{ + auto result = Slice(start); + result.m_endPosExclusive = result.m_startPos + length; + return result; +} + +StringSpan StringSpan::TrimStart(std::set const& chars) const +{ + StringSpan result = *this; + + auto pos = result.m_startPos; + for (; pos < result.m_endPosExclusive; ++pos) + { + if (chars.find(result.m_stringPtr->operator[](static_cast(pos))) == chars.end()) + { + break; + } + } + + result.m_startPos = pos; + return result; +} + +StringSpan StringSpan::Trim(std::set const& chars) const +{ + StringSpan result = TrimStart(chars); + + auto endPos = m_endPosExclusive; + for (; endPos > result.m_startPos; --endPos) + { + if (chars.find(result.m_stringPtr->operator[](static_cast(endPos - 1))) == chars.end()) + { + break; + } + } + + result.m_endPosExclusive = endPos; + return result; +} + +int StringSpan::IndexOfAny(std::set const& chars) const +{ + for (auto pos = m_startPos; pos < m_endPosExclusive; ++pos) + { + if (chars.find(m_stringPtr->operator[](static_cast(pos))) != chars.end()) + { + return pos - m_startPos; + } + } + + return -1; +} + +bool StringSpan::CaseInsensitiveEquals(StringSpan const& other) const +{ + auto const length = Length(); + if (length != other.Length()) + { + return false; + } + + for (auto offset = 0; offset < length; ++offset) + { + if (StringExtensions::ToLower(m_stringPtr->operator[](static_cast(m_startPos + offset))) + != StringExtensions::ToLower( + other.m_stringPtr->operator[](static_cast(other.m_startPos + offset)))) + { + return false; + } + } + + return true; +} +} // namespace diff --git a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp index 0f73918f6..e8fd9ffc6 100644 --- a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp +++ b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp @@ -1,16 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-License-Identifier: MIT -#include "azure/core/credentials/credentials.hpp" #include "azure/core/http/policies/policy.hpp" +#include "azure/core/credentials/credentials.hpp" +#include "azure/core/internal/credentials/authorization_challenge_parser.hpp" + #include +using Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy; + using Azure::Core::Context; -using namespace Azure::Core::Http; -using namespace Azure::Core::Http::Policies; -using namespace Azure::Core::Http::Policies::_internal; using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenRequestContext; +using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper; +using Azure::Core::Http::RawResponse; +using Azure::Core::Http::Policies::NextHttpPolicy; std::unique_ptr BearerTokenAuthenticationPolicy::Send( Request& request, @@ -23,17 +28,55 @@ std::unique_ptr BearerTokenAuthenticationPolicy::Send( "Bearer token authentication is not permitted for non TLS protected (https) endpoints."); } + auto result = AuthorizeAndSendRequest(request, nextPolicy, context); { - std::lock_guard lock(m_accessTokenMutex); - - if (std::chrono::system_clock::now() - > (m_accessToken.ExpiresOn - m_tokenRequestContext.MinimumExpiration)) + auto const& response = *result; + auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response); + if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context)) { - m_accessToken = m_credential->GetToken(m_tokenRequestContext, context); + result = nextPolicy.Send(request, context); } - - request.SetHeader("authorization", "Bearer " + m_accessToken.Token); } + return result; +} + +std::unique_ptr BearerTokenAuthenticationPolicy::AuthorizeAndSendRequest( + Request& request, + NextHttpPolicy& nextPolicy, + Context const& context) const +{ + AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context); return nextPolicy.Send(request, context); } + +bool BearerTokenAuthenticationPolicy::AuthorizeRequestOnChallenge( + std::string const& challenge, + Request& request, + Context const& context) const +{ + static_cast(challenge); + static_cast(request); + static_cast(context); + + return false; +} + +void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest( + Request& request, + Credentials::TokenRequestContext const& tokenRequestContext, + Context const& context) const +{ + std::lock_guard lock(m_accessTokenMutex); + + if (tokenRequestContext.TenantId != m_accessTokenContext.TenantId + || tokenRequestContext.Scopes != m_accessTokenContext.Scopes + || std::chrono::system_clock::now() + > (m_accessToken.ExpiresOn - tokenRequestContext.MinimumExpiration)) + { + m_accessToken = m_credential->GetToken(tokenRequestContext, context); + m_accessTokenContext = tokenRequestContext; + } + + request.SetHeader("authorization", "Bearer " + m_accessToken.Token); +} diff --git a/sdk/core/azure-core/test/ut/CMakeLists.txt b/sdk/core/azure-core/test/ut/CMakeLists.txt index 864453d32..fb1617e75 100644 --- a/sdk/core/azure-core/test/ut/CMakeLists.txt +++ b/sdk/core/azure-core/test/ut/CMakeLists.txt @@ -42,6 +42,7 @@ endif() add_executable ( azure-core-test + authorization_challenge_parser_test.cpp azure_core_test.cpp base64_test.cpp bearer_token_authentication_policy_test.cpp diff --git a/sdk/core/azure-core/test/ut/authorization_challenge_parser_test.cpp b/sdk/core/azure-core/test/ut/authorization_challenge_parser_test.cpp new file mode 100644 index 000000000..312e0af30 --- /dev/null +++ b/sdk/core/azure-core/test/ut/authorization_challenge_parser_test.cpp @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include + +#include + +using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper; +using Azure::Core::Credentials::_internal::AuthorizationChallengeParser; + +using Azure::Core::Http::HttpStatusCode; +using Azure::Core::Http::RawResponse; + +namespace { +RawResponse CreateRawResponseWithWwwAuthHeader( + std::string const& value, + HttpStatusCode httpStatusCode = HttpStatusCode::Unauthorized) +{ + RawResponse result(1, 1, httpStatusCode, "Test"); + result.SetHeader("WWW-Authenticate", value); + return result; +} + +std::string GetChallengeParameterFromResponse( + RawResponse const& response, + std::string const& challengeScheme, + std::string const& challengeParameter) +{ + return AuthorizationChallengeParser::GetChallengeParameter( + AuthorizationChallengeHelper::GetChallenge(response), challengeScheme, challengeParameter); +} +} // namespace + +TEST(AuthorizationChallengeParser, Simple) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Bearer key=value"), "Bearer", "key"), + "value"); +} + +TEST(AuthorizationChallengeParser, EmptyString) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse(CreateRawResponseWithWwwAuthHeader(""), "Bearer", "key"), + ""); +} + +TEST(AuthorizationChallengeParser, Non401) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Bearer key=value", HttpStatusCode::Ok), + "Bearer", + "key"), + ""); +} + +TEST(AuthorizationChallengeParser, NoHeader) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + RawResponse(1, 1, HttpStatusCode::Unauthorized, "Test"), "Bearer", "key"), + ""); +} + +TEST(AuthorizationChallengeParser, KeyNotFound) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Bearer otherkey=value"), "Bearer", "key"), + ""); +} + +TEST(AuthorizationChallengeParser, SchemeNotFound) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Basic key=value"), "Bearer", "key"), + ""); +} + +TEST(AuthorizationChallengeParser, NotFoundForScheme) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Basic key=value, Bearer otherkey=value"), + "Bearer", + "key"), + ""); +} + +TEST(AuthorizationChallengeParser, MultiplSchemeMatch) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader( + "Basic key=value1, Bearer key=value2, Digest key=value3"), + "Bearer", + "key"), + "value2"); +} + +TEST(AuthorizationChallengeParser, Quoted) +{ + EXPECT_EQ( + GetChallengeParameterFromResponse( + CreateRawResponseWithWwwAuthHeader("Bearer key=\"v a l u e\""), "Bearer", "key"), + "v a l u e"); +} + +TEST(AuthorizationChallengeParser, CaeInsufficientClaimsChallenge) +{ + auto const response = CreateRawResponseWithWwwAuthHeader( + "Bearer realm=\"\", " + "authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", " + "client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", " + "claims=\"eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=\""); + + EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "realm"), ""); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"), + "https://login.microsoftonline.com/common/oauth2/authorize"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "client_id"), + "00000003-0000-0000-c000-000000000000"); + + EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "insufficient_claims"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "claims"), + "eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0="); +} + +TEST(AuthorizationChallengeParser, CaeSessionsRevokedClaimsChallenge) +{ + auto const response = CreateRawResponseWithWwwAuthHeader( + "Bearer authorization_uri=\"https://login.windows-ppe.net/\", error=\"invalid_token\", " + "error_description=\"User session has been revoked\", " + "claims=" + "\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + "\""); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"), + "https://login.windows-ppe.net/"); + + EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "invalid_token"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "error_description"), + "User session has been revoked"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "claims"), + "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="); +} + +TEST(AuthorizationChallengeParser, KeyVaultChallenge) +{ + auto const response = CreateRawResponseWithWwwAuthHeader( + "Bearer " + "authorization=\"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47\", " + "resource=\"https://vault.azure.net\""); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "authorization"), + "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "resource"), "https://vault.azure.net"); +} + +TEST(AuthorizationChallengeParser, ArmChallenge) +{ + auto const response = CreateRawResponseWithWwwAuthHeader( + "Bearer authorization_uri=\"https://login.windows.net/\", error=\"invalid_token\", " + "error_description=\"The authentication failed because of missing 'Authorization' header.\""); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"), + "https://login.windows.net/"); + + EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "invalid_token"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "error_description"), + "The authentication failed because of missing 'Authorization' header."); +} + +TEST(AuthorizationChallengeParser, StorageChallenge) +{ + auto const response = CreateRawResponseWithWwwAuthHeader( + "Bearer " + "authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/" + "oauth2/authorize resource_id=https://storage.azure.com"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"), + "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize"); + + EXPECT_EQ( + GetChallengeParameterFromResponse(response, "Bearer", "resource_id"), + "https://storage.azure.com"); +} diff --git a/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp b/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp index c38c6aa19..7f62e91c8 100644 --- a/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/bearer_token_authentication_policy_test.cpp @@ -8,34 +8,44 @@ #include +using Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy; + +using Azure::Core::Context; +using Azure::Core::Url; +using Azure::Core::Credentials::AccessToken; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredential; +using Azure::Core::Credentials::TokenRequestContext; +using Azure::Core::Http::HttpMethod; +using Azure::Core::Http::HttpStatusCode; +using Azure::Core::Http::RawResponse; +using Azure::Core::Http::Request; +using Azure::Core::Http::_internal::HttpPipeline; +using Azure::Core::Http::Policies::HttpPolicy; +using Azure::Core::Http::Policies::NextHttpPolicy; + namespace { -class TestTokenCredential final : public Azure::Core::Credentials::TokenCredential { +class TestTokenCredential final : public TokenCredential { private: - std::shared_ptr m_accessToken; + std::shared_ptr m_accessToken; public: - explicit TestTokenCredential( - std::shared_ptr accessToken) + explicit TestTokenCredential(std::shared_ptr accessToken) : TokenCredential("TestTokenCredential"), m_accessToken(accessToken) { } - Azure::Core::Credentials::AccessToken GetToken( - Azure::Core::Credentials::TokenRequestContext const&, - Azure::Core::Context const&) const override + AccessToken GetToken(TokenRequestContext const&, Context const&) const override { return *m_accessToken; } }; -class TestTransportPolicy final : public Azure::Core::Http::Policies::HttpPolicy { +class TestTransportPolicy final : public HttpPolicy { public: - std::unique_ptr Send( - Azure::Core::Http::Request&, - Azure::Core::Http::Policies::NextHttpPolicy, - Azure::Core::Context const&) const override + std::unique_ptr Send(Request&, NextHttpPolicy, Context const&) const override { - return nullptr; + return std::make_unique(1, 1, HttpStatusCode::Ok, "TestStatus"); } std::unique_ptr Clone() const override @@ -43,34 +53,31 @@ public: return std::make_unique(*this); } }; - } // namespace TEST(BearerTokenAuthenticationPolicy, InitialGet) { using namespace std::chrono_literals; - auto accessToken = std::make_shared(); + auto accessToken = std::make_shared(); - std::vector> policies; + std::vector> policies; - Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + TokenRequestContext tokenRequestContext; tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; - policies.emplace_back( - std::make_unique( - std::make_shared(accessToken), tokenRequestContext)); + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + HttpPipeline pipeline(policies); { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); { auto const headers = request.GetHeaders(); @@ -84,37 +91,34 @@ TEST(BearerTokenAuthenticationPolicy, InitialGet) TEST(BearerTokenAuthenticationPolicy, ReuseWhileValid) { using namespace std::chrono_literals; - auto accessToken = std::make_shared(); + auto accessToken = std::make_shared(); - std::vector> policies; + std::vector> policies; - Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + TokenRequestContext tokenRequestContext; tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; - policies.emplace_back( - std::make_unique( - std::make_shared(accessToken), tokenRequestContext)); + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + HttpPipeline pipeline(policies); { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 5min}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); } { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); { auto const headers = request.GetHeaders(); @@ -128,37 +132,34 @@ TEST(BearerTokenAuthenticationPolicy, ReuseWhileValid) TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry) { using namespace std::chrono_literals; - auto accessToken = std::make_shared(); + auto accessToken = std::make_shared(); - std::vector> policies; + std::vector> policies; - Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + TokenRequestContext tokenRequestContext; tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; - policies.emplace_back( - std::make_unique( - std::make_shared(accessToken), tokenRequestContext)); + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + HttpPipeline pipeline(policies); { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 2min}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); } { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); { auto const headers = request.GetHeaders(); @@ -172,37 +173,34 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry) TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry) { using namespace std::chrono_literals; - auto accessToken = std::make_shared(); + auto accessToken = std::make_shared(); - std::vector> policies; + std::vector> policies; - Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + TokenRequestContext tokenRequestContext; tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; - policies.emplace_back( - std::make_unique( - std::make_shared(accessToken), tokenRequestContext)); + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + HttpPipeline pipeline(policies); { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now()}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); } { - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com")); + Request request(HttpMethod::Get, Url("https://www.azure.com")); *accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h}; - pipeline.Send(request, Azure::Core::Context()); + pipeline.Send(request, Context()); { auto const headers = request.GetHeaders(); @@ -213,30 +211,251 @@ TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry) } } -TEST(BearerTokenAuthenticationPolicy, HttpEndpoint) +TEST(BearerTokenAuthenticationPolicy, NonHttps) { using namespace std::chrono_literals; - auto accessToken = std::make_shared(); + auto accessToken = std::make_shared(); - std::vector> policies; + std::vector> policies; - Azure::Core::Credentials::TokenRequestContext tokenRequestContext; + TokenRequestContext tokenRequestContext; tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; - policies.emplace_back( - std::make_unique( - std::make_shared(accessToken), tokenRequestContext)); + policies.emplace_back(std::make_unique( + std::make_shared(accessToken), tokenRequestContext)); policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + HttpPipeline pipeline(policies); - Azure::Core::Http::Request request( - Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("http://www.azure.com")); + Request request(HttpMethod::Get, Url("http://www.azure.com")); *accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now()}; - EXPECT_THROW( - static_cast(pipeline.Send(request, Azure::Core::Context())), - Azure::Core::Credentials::AuthenticationException); + EXPECT_THROW(static_cast(pipeline.Send(request, Context())), AuthenticationException); +} + +namespace { +class TestBearerTokenAuthenticationPolicy final : public BearerTokenAuthenticationPolicy { +public: + TestBearerTokenAuthenticationPolicy( + std::shared_ptr credential, + TokenRequestContext tokenRequestContext) + : BearerTokenAuthenticationPolicy(credential, tokenRequestContext) + { + } + + std::unique_ptr Clone() const override + { + return std::make_unique(*this); + } + +protected: + bool AuthorizeRequestOnChallenge(std::string const&, Request&, Context const&) const override + { + EXPECT_FALSE("AuthorizeRequestOnChallenge() should not get called if AuthorizeAndSendRequest() " + "was successful."); + return false; + } +}; + +class TestTokenCredentialForBearerTokenAuthenticationPolicy final : public TokenCredential { +public: + explicit TestTokenCredentialForBearerTokenAuthenticationPolicy() + : TokenCredential("TestTokenCredentialForBearerTokenAuthenticationPolicy") + { + } + + AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const&) + const override + { + EXPECT_EQ(tokenRequestContext.Scopes.size(), 1); + EXPECT_EQ(tokenRequestContext.Scopes[0], "https://microsoft.com/.default"); + + EXPECT_TRUE(tokenRequestContext.TenantId.empty()); + + AccessToken result; + result.Token = "ACCESSTOKEN"; + return result; + } +}; + +class TestChallengeBasedAuthenticationPolicy final : public BearerTokenAuthenticationPolicy { +private: + bool m_successfulAuthOnChallenge; + +public: + TestChallengeBasedAuthenticationPolicy( + std::shared_ptr credential, + TokenRequestContext tokenRequestContext, + bool successfulAuthOnChallenge) + : BearerTokenAuthenticationPolicy(credential, tokenRequestContext), + m_successfulAuthOnChallenge(successfulAuthOnChallenge) + { + } + + std::unique_ptr Clone() const override + { + return std::make_unique(*this); + } + +protected: + std::unique_ptr AuthorizeAndSendRequest( + Request& request, + NextHttpPolicy& nextHttpPolicy, + Context const& context) const override + { + EXPECT_EQ(request.GetUrl().GetAbsoluteUrl(), "https://www.azure.com"); + + TokenRequestContext trc; + trc.Scopes = {"https://visualstudio.com/.default"}; + trc.TenantId = "TestTenantId1"; + + AuthenticateAndAuthorizeRequest(request, trc, context); + + // Simulate as if we got challenge-based response and not a HTTP 200. + static_cast(nextHttpPolicy.Send(request, context)); + auto response = std::make_unique(1, 1, HttpStatusCode::Unauthorized, "TestStatus"); + response->SetHeader("WWW-Authenticate", "TestChallenge"); + return response; + } + + bool AuthorizeRequestOnChallenge( + std::string const& challenge, + Request& request, + Context const& context) const override + { + EXPECT_EQ(challenge, "TestChallenge"); + + TokenRequestContext trc; + trc.Scopes = {"https://xbox.com/.default"}; + trc.TenantId = "TestTenantId2"; + + if (m_successfulAuthOnChallenge) + { + AuthenticateAndAuthorizeRequest(request, trc, context); + return true; + } + + return false; + } +}; + +class TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy final : public TokenCredential { +private: + mutable int m_invokedTimes; + +public: + explicit TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy() + : TokenCredential("TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy"), + m_invokedTimes(0) + { + } + + AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const&) + const override + { + ++m_invokedTimes; + + EXPECT_GE(m_invokedTimes, 1); + EXPECT_LE(m_invokedTimes, 2); + + if (m_invokedTimes == 1) + { + EXPECT_EQ(tokenRequestContext.Scopes.size(), 1); + EXPECT_EQ(tokenRequestContext.Scopes[0], "https://visualstudio.com/.default"); + EXPECT_EQ(tokenRequestContext.TenantId, "TestTenantId1"); + + AccessToken result; + result.Token = "ACCESSTOKEN1"; + return result; + } + + EXPECT_EQ(tokenRequestContext.Scopes.size(), 1); + EXPECT_EQ(tokenRequestContext.Scopes[0], "https://xbox.com/.default"); + EXPECT_EQ(tokenRequestContext.TenantId, "TestTenantId2"); + + AccessToken result; + result.Token = "ACCESSTOKEN2"; + return result; + } +}; +} // namespace + +TEST(BearerTokenAuthenticationPolicy, ChallengeBasedSupport) +{ + std::vector> policies; + + TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; + + policies.emplace_back(std::make_unique( + std::make_shared(), + tokenRequestContext)); + + policies.emplace_back(std::make_unique()); + + HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Url("https://www.azure.com")); + + static_cast(pipeline.Send(request, Context())); + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN"); +} + +TEST(BearerTokenAuthenticationPolicy, ChallengeBasedSuccess) +{ + std::vector> policies; + + TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; + + policies.emplace_back(std::make_unique( + std::make_shared(), + tokenRequestContext, + true)); + + policies.emplace_back(std::make_unique()); + + HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Url("https://www.azure.com")); + + auto const response = pipeline.Send(request, Context()); + EXPECT_EQ(response->GetStatusCode(), HttpStatusCode::Ok); + + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2"); +} + +TEST(BearerTokenAuthenticationPolicy, ChallengeBasedFailure) +{ + std::vector> policies; + + TokenRequestContext tokenRequestContext; + tokenRequestContext.Scopes = {"https://microsoft.com/.default"}; + + policies.emplace_back(std::make_unique( + std::make_shared(), + tokenRequestContext, + false)); + + policies.emplace_back(std::make_unique()); + + HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Url("https://www.azure.com")); + + auto const response = pipeline.Send(request, Context()); + EXPECT_EQ(response->GetStatusCode(), HttpStatusCode::Unauthorized); + + auto const headers = request.GetHeaders(); + auto const authHeader = headers.find("authorization"); + EXPECT_NE(authHeader, headers.end()); + EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1"); } diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 5ce4f5a8d..0b3062f9e 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added support for challenge-based and multi-tenant authentication. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index 25169d9ef..4cd89fabc 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -66,6 +66,7 @@ set( src/private/identity_log.hpp src/private/managed_identity_source.hpp src/private/package_version.hpp + src/private/tenant_id_resolver.hpp src/private/token_credential_impl.hpp src/azure_cli_credential.cpp src/chained_token_credential.cpp @@ -76,6 +77,7 @@ set( src/environment_credential.cpp src/managed_identity_credential.cpp src/managed_identity_source.cpp + src/tenant_id_resolver.cpp src/token_cache.cpp src/token_credential_impl.cpp ) diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp index 53aec4f32..0de97ff2f 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -17,6 +17,7 @@ #include #include +#include namespace Azure { namespace Identity { /** @@ -36,6 +37,13 @@ namespace Azure { namespace Identity { */ DateTime::duration CliProcessTimeout = std::chrono::seconds(13); // Value was taken from .NET SDK. + + /** + * @brief For multi-tenant applications, specifies additional tenants for which the credential + * may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens + * for any tenant in which the application is installed. + */ + std::vector AdditionallyAllowedTenants; }; /** @@ -49,14 +57,16 @@ namespace Azure { namespace Identity { : public Core::Credentials::TokenCredential { protected: _detail::TokenCache m_tokenCache; + std::vector m_additionallyAllowedTenants; std::string m_tenantId; DateTime::duration m_cliProcessTimeout; private: explicit AzureCliCredential( + Core::Credentials::TokenCredentialOptions const& options, std::string tenantId, DateTime::duration cliProcessTimeout, - Core::Credentials::TokenCredentialOptions const& options); + std::vector additionallyAllowedTenants); void ThrowIfNotSafeCmdLineInput(std::string const& input, std::string const& description) const; diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp index 669a849f2..06eb44a62 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_certificate_credential.hpp @@ -18,6 +18,7 @@ #include #include +#include namespace Azure { namespace Identity { namespace _detail { @@ -50,6 +51,13 @@ namespace Azure { namespace Identity { * https://docs.microsoft.com/azure/active-directory/develop/authentication-national-cloud. */ std::string AuthorityHost = _detail::ClientCredentialCore::AadGlobalAuthority; + + /** + * @brief For multi-tenant applications, specifies additional tenants for which the credential + * may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens + * for any tenant in which the application is installed. + */ + std::vector AdditionallyAllowedTenants; }; /** @@ -72,6 +80,7 @@ namespace Azure { namespace Identity { std::string const& clientId, std::string const& clientCertificatePath, std::string const& authorityHost, + std::vector additionallyAllowedTenants, Core::Credentials::TokenCredentialOptions const& options); public: diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp index 494b938bd..f70bd5c6a 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp @@ -17,6 +17,7 @@ #include #include +#include namespace Azure { namespace Identity { namespace _detail { @@ -38,6 +39,13 @@ namespace Azure { namespace Identity { * https://docs.microsoft.com/azure/active-directory/develop/authentication-national-cloud. */ std::string AuthorityHost = _detail::ClientCredentialCore::AadGlobalAuthority; + + /** + * @brief For multi-tenant applications, specifies additional tenants for which the credential + * may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens + * for any tenant in which the application is installed. + */ + std::vector AdditionallyAllowedTenants; }; /** @@ -57,6 +65,7 @@ namespace Azure { namespace Identity { std::string const& clientId, std::string const& clientSecret, std::string const& authorityHost, + std::vector additionallyAllowedTenants, Core::Credentials::TokenCredentialOptions const& options); public: diff --git a/sdk/identity/azure-identity/inc/azure/identity/detail/client_credential_core.hpp b/sdk/identity/azure-identity/inc/azure/identity/detail/client_credential_core.hpp index 0a22b5bc7..d6af9d43c 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/detail/client_credential_core.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/detail/client_credential_core.hpp @@ -9,21 +9,33 @@ #include #include +#include namespace Azure { namespace Identity { namespace _detail { class ClientCredentialCore final { + std::vector m_additionallyAllowedTenants; Core::Url m_authorityHost; std::string m_tenantId; - bool m_isAdfs; public: AZ_IDENTITY_DLLEXPORT static std::string const AadGlobalAuthority; - explicit ClientCredentialCore(std::string tenantId, std::string const& authorityHost); + explicit ClientCredentialCore( + std::string tenantId, + std::string const& authorityHost, + std::vector additionallyAllowedTenants); - Core::Url GetRequestUrl() const; + Core::Url GetRequestUrl(std::string const& tenantId) const; - std::string GetScopesString(decltype(Core::Credentials::TokenRequestContext::Scopes) - const& scopes) const; + std::string GetScopesString( + std::string const& tenantId, + decltype(Core::Credentials::TokenRequestContext::Scopes) const& scopes) const; + + std::string const& GetTenantId() const { return m_tenantId; } + + std::vector const& GetAdditionallyAllowedTenants() const + { + return m_additionallyAllowedTenants; + } }; }}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp index 2434302a2..3e2ce82bc 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace Azure { namespace Identity { namespace _detail { /** @@ -39,26 +40,27 @@ namespace Azure { namespace Identity { namespace _detail { // A test hook that gets invoked before item write lock gets acquired. virtual void OnBeforeItemWriteLock() const {}; + struct CacheKey + { + std::string Scope; + std::string TenantId; + }; + + struct CacheKeyComparator + { + bool operator()(CacheKey const& lhs, CacheKey const& rhs) const + { + return std::tie(lhs.Scope, lhs.TenantId) < std::tie(rhs.Scope, rhs.TenantId); + } + }; + struct CacheValue { Core::Credentials::AccessToken AccessToken; std::shared_timed_mutex ElementMutex; }; - // The current cache Key, std::string Scopes, may later evolve to a struct that contains more - // fields. All that depends on the fields in the TokenRequestContext that are used as - // characteristics that go into the network request that gets the token. - // If tomorrow we add Multi-Tenant Authentication, and the TenantID stops being an immutable - // characteristic of a credential instance, but instead becomes variable depending on the fields - // of the TokenRequestContext that are taken into consideration as network request for the token - // is being sent, it should go into what will form the new CacheKey struct. - // i.e. we want all the variable inputs for obtaining a token to be a part of the key, because - // we want to have the same kind of result. There should be no "hidden variables". - // Otherwise, the cache will stop functioning properly, because the value you'd get from cache - // for a given key will fail to authenticate, but if the cache ends up calling the getNewToken - // callback, you'll authenticate successfully (however the other caller who need to get the - // token for slightly different context will not be as lucky). - mutable std::map> m_cache; + mutable std::map, CacheKeyComparator> m_cache; mutable std::shared_timed_mutex m_cacheMutex; private: @@ -73,7 +75,7 @@ namespace Azure { namespace Identity { namespace _detail { // Gets item from cache, or creates it, puts into cache, and returns. std::shared_ptr GetOrCreateValue( - std::string const& key, + CacheKey const& key, DateTime::duration minimumExpiration) const; public: @@ -85,6 +87,7 @@ namespace Azure { namespace Identity { namespace _detail { * provided, caches it, and returns its value. * * @param scopeString Authentication scopes (or resource) as string. + * @param tenantId TenantId for authentication. * @param minimumExpiration Minimum token lifetime for the cached value to be returned. * @param getNewToken Function to get the new token for the given \p scopeString, in case when * cache does not have it, or if its remaining lifetime is less than \p minimumExpiration. @@ -94,6 +97,7 @@ namespace Azure { namespace Identity { namespace _detail { */ Core::Credentials::AccessToken GetToken( std::string const& scopeString, + std::string const& tenantId, DateTime::duration minimumExpiration, std::function const& getNewToken) const; }; diff --git a/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp index 72c447b48..73d7018ca 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp @@ -12,37 +12,61 @@ #include #include +#include +#include namespace Azure { namespace Identity { + /** + * @brief Options for token authentication. + * + */ + struct EnvironmentCredentialOptions final : public Core::Credentials::TokenCredentialOptions + { + /** + * @brief For multi-tenant applications, specifies additional tenants for which the credential + * may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens + * for any tenant in which the application is installed. + */ + std::vector AdditionallyAllowedTenants; + }; + /** * @brief Environment Credential initializes an Azure credential, based on the system environment * variables being set. * + * @note May read from the following environment variables: + * - `AZURE_TENANT_ID` + * - `AZURE_CLIENT_ID` + * - `AZURE_CLIENT_SECRET` + * - `AZURE_CLIENT_CERTIFICATE_PATH` + * - `AZURE_CLIENT_CERTIFICATE_PASSWORD` + * - `AZURE_CLIENT_SEND_CERTIFICATE_CHAIN` + * - `AZURE_USERNAME` + * - `AZURE_PASSWORD` + * - `AZURE_AUTHORITY_HOST` */ class EnvironmentCredential final : public Core::Credentials::TokenCredential { private: std::unique_ptr m_credentialImpl; + explicit EnvironmentCredential( + Core::Credentials::TokenCredentialOptions const& options, + std::vector const& additionallyAllowedTenants); + public: /** * @brief Constructs an Environment Credential. - * * @param options Options for token retrieval. - * - * @note May read from the following environment variables: - * - AZURE_TENANT_ID - * - AZURE_CLIENT_ID - * - AZURE_CLIENT_SECRET - * - AZURE_CLIENT_CERTIFICATE_PATH - * - AZURE_CLIENT_CERTIFICATE_PASSWORD - * - AZURE_CLIENT_SEND_CERTIFICATE_CHAIN - * - AZURE_USERNAME - * - AZURE_PASSWORD - * - AZURE_AUTHORITY_HOST */ explicit EnvironmentCredential( - Azure::Core::Credentials::TokenCredentialOptions options - = Azure::Core::Credentials::TokenCredentialOptions()); + Core::Credentials::TokenCredentialOptions const& options + = Core::Credentials::TokenCredentialOptions()); + + /** + * @brief Constructs an Environment Credential. + * @param options Options for token retrieval. + */ + explicit EnvironmentCredential(EnvironmentCredentialOptions const& options); /** * @brief Gets an authentication token. diff --git a/sdk/identity/azure-identity/src/azure_cli_credential.cpp b/sdk/identity/azure-identity/src/azure_cli_credential.cpp index 443088770..d4fc308ac 100644 --- a/sdk/identity/azure-identity/src/azure_cli_credential.cpp +++ b/sdk/identity/azure-identity/src/azure_cli_credential.cpp @@ -4,6 +4,7 @@ #include "azure/identity/azure_cli_credential.hpp" #include "private/identity_log.hpp" +#include "private/tenant_id_resolver.hpp" #include "private/token_credential_impl.hpp" #include @@ -47,6 +48,7 @@ using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Identity::AzureCliCredentialOptions; using Azure::Identity::_detail::IdentityLog; +using Azure::Identity::_detail::TenantIdResolver; using Azure::Identity::_detail::TokenCache; using Azure::Identity::_detail::TokenCredentialImpl; @@ -77,11 +79,13 @@ void AzureCliCredential::ThrowIfNotSafeCmdLineInput( } } AzureCliCredential::AzureCliCredential( + Core::Credentials::TokenCredentialOptions const& options, std::string tenantId, DateTime::duration cliProcessTimeout, - Core::Credentials::TokenCredentialOptions const& options) - : TokenCredential("AzureCliCredential"), m_tenantId(std::move(tenantId)), - m_cliProcessTimeout(std::move(cliProcessTimeout)) + std::vector additionallyAllowedTenants) + : TokenCredential("AzureCliCredential"), + m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)), + m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout)) { static_cast(options); @@ -95,15 +99,20 @@ AzureCliCredential::AzureCliCredential( } AzureCliCredential::AzureCliCredential(AzureCliCredentialOptions const& options) - : AzureCliCredential(options.TenantId, options.CliProcessTimeout, options) + : AzureCliCredential( + options, + options.TenantId, + options.CliProcessTimeout, + options.AdditionallyAllowedTenants) { } AzureCliCredential::AzureCliCredential(TokenCredentialOptions const& options) : AzureCliCredential( + options, AzureCliCredentialOptions{}.TenantId, AzureCliCredentialOptions{}.CliProcessTimeout, - options) + AzureCliCredentialOptions{}.AdditionallyAllowedTenants) { } @@ -133,15 +142,17 @@ AccessToken AzureCliCredential::GetToken( Context const& context) const { auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false); + auto const tenantId + = TenantIdResolver::Resolve(m_tenantId, tokenRequestContext, m_additionallyAllowedTenants); - // TokenCache::GetToken() can only use the lambda argument when they are being executed. They are - // not supposed to keep a reference to lambda argument to call it later. Therefore, any capture - // made here will outlive the possible time frame when the lambda might get called. - return m_tokenCache.GetToken(scopes, tokenRequestContext.MinimumExpiration, [&]() { + // TokenCache::GetToken() can only use the lambda argument when they are being executed. They + // are not supposed to keep a reference to lambda argument to call it later. Therefore, any + // capture made here will outlive the possible time frame when the lambda might get called. + return m_tokenCache.GetToken(scopes, tenantId, tokenRequestContext.MinimumExpiration, [&]() { try { auto const azCliResult - = RunShellCommand(GetAzCommand(scopes, m_tenantId), m_cliProcessTimeout, context); + = RunShellCommand(GetAzCommand(scopes, tenantId), m_cliProcessTimeout, context); try { diff --git a/sdk/identity/azure-identity/src/client_certificate_credential.cpp b/sdk/identity/azure-identity/src/client_certificate_credential.cpp index 9ffd31fbd..e8dd7575e 100644 --- a/sdk/identity/azure-identity/src/client_certificate_credential.cpp +++ b/sdk/identity/azure-identity/src/client_certificate_credential.cpp @@ -3,6 +3,7 @@ #include "azure/identity/client_certificate_credential.hpp" +#include "private/tenant_id_resolver.hpp" #include "private/token_credential_impl.hpp" #include @@ -33,6 +34,7 @@ using Azure::Core::Credentials::AuthenticationException; using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Http::HttpMethod; +using Azure::Identity::_detail::TenantIdResolver; using Azure::Identity::_detail::TokenCredentialImpl; namespace { @@ -79,9 +81,10 @@ ClientCertificateCredential::ClientCertificateCredential( std::string const& clientId, std::string const& clientCertificatePath, std::string const& authorityHost, + std::vector additionallyAllowedTenants, TokenCredentialOptions const& options) : TokenCredential("ClientCertificateCredential"), - m_clientCredentialCore(tenantId, authorityHost), + m_clientCredentialCore(tenantId, authorityHost, additionallyAllowedTenants), m_tokenCredentialImpl(std::make_unique(options)), m_requestBody( std::string( @@ -178,6 +181,7 @@ ClientCertificateCredential::ClientCertificateCredential( clientId, clientCertificatePath, options.AuthorityHost, + options.AdditionallyAllowedTenants, options) { } @@ -192,6 +196,7 @@ ClientCertificateCredential::ClientCertificateCredential( clientId, clientCertificatePath, ClientCertificateCredentialOptions{}.AuthorityHost, + ClientCertificateCredentialOptions{}.AdditionallyAllowedTenants, options) { } @@ -202,13 +207,19 @@ AccessToken ClientCertificateCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - auto const scopesStr = m_clientCredentialCore.GetScopesString(tokenRequestContext.Scopes); + auto const tenantId = TenantIdResolver::Resolve( + m_clientCredentialCore.GetTenantId(), + tokenRequestContext, + m_clientCredentialCore.GetAdditionallyAllowedTenants()); + + auto const scopesStr + = m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes); // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() { return m_tokenCredentialImpl->GetToken(context, [&]() { auto body = m_requestBody; if (!scopesStr.empty()) @@ -216,7 +227,7 @@ AccessToken ClientCertificateCredential::GetToken( body += "&scope=" + scopesStr; } - auto const requestUrl = m_clientCredentialCore.GetRequestUrl(); + auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId); std::string assertion = m_tokenHeaderEncoded; { diff --git a/sdk/identity/azure-identity/src/client_credential_core.cpp b/sdk/identity/azure-identity/src/client_credential_core.cpp index 4474e12b1..2124b7446 100644 --- a/sdk/identity/azure-identity/src/client_credential_core.cpp +++ b/sdk/identity/azure-identity/src/client_credential_core.cpp @@ -3,6 +3,7 @@ #include "azure/identity/detail/client_credential_core.hpp" +#include "private/tenant_id_resolver.hpp" #include "private/token_credential_impl.hpp" #include @@ -11,29 +12,35 @@ using Azure::Identity::_detail::ClientCredentialCore; using Azure::Core::Url; using Azure::Core::Credentials::TokenRequestContext; +using Azure::Identity::_detail::TenantIdResolver; using Azure::Identity::_detail::TokenCredentialImpl; decltype(ClientCredentialCore::AadGlobalAuthority) ClientCredentialCore::AadGlobalAuthority = "https://login.microsoftonline.com/"; -ClientCredentialCore::ClientCredentialCore(std::string tenantId, std::string const& authorityHost) - : m_authorityHost(Url(authorityHost)), m_tenantId(std::move(tenantId)) +ClientCredentialCore::ClientCredentialCore( + std::string tenantId, + std::string const& authorityHost, + std::vector additionallyAllowedTenants) + : m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)), + m_authorityHost(Url(authorityHost)), m_tenantId(std::move(tenantId)) { - // ADFS is the Active Directory Federation Service, a tenant ID that is used in Azure Stack. - m_isAdfs = m_tenantId == "adfs"; } -Url ClientCredentialCore::GetRequestUrl() const +Url ClientCredentialCore::GetRequestUrl(std::string const& tenantId) const { auto requestUrl = m_authorityHost; - requestUrl.AppendPath(m_tenantId); - requestUrl.AppendPath(m_isAdfs ? "oauth2/token" : "oauth2/v2.0/token"); + requestUrl.AppendPath(tenantId); + requestUrl.AppendPath(TenantIdResolver::IsAdfs(tenantId) ? "oauth2/token" : "oauth2/v2.0/token"); return requestUrl; } -std::string ClientCredentialCore::GetScopesString(decltype(TokenRequestContext::Scopes) - const& scopes) const +std::string ClientCredentialCore::GetScopesString( + std::string const& tenantId, + decltype(TokenRequestContext::Scopes) const& scopes) const { - return scopes.empty() ? std::string() : TokenCredentialImpl::FormatScopes(scopes, m_isAdfs); + return scopes.empty() + ? std::string() + : TokenCredentialImpl::FormatScopes(scopes, TenantIdResolver::IsAdfs(tenantId)); } diff --git a/sdk/identity/azure-identity/src/client_secret_credential.cpp b/sdk/identity/azure-identity/src/client_secret_credential.cpp index c0e8355f6..1e8dadddc 100644 --- a/sdk/identity/azure-identity/src/client_secret_credential.cpp +++ b/sdk/identity/azure-identity/src/client_secret_credential.cpp @@ -3,6 +3,7 @@ #include "azure/identity/client_secret_credential.hpp" +#include "private/tenant_id_resolver.hpp" #include "private/token_credential_impl.hpp" using Azure::Identity::ClientSecretCredential; @@ -13,6 +14,7 @@ using Azure::Core::Credentials::AccessToken; using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Http::HttpMethod; +using Azure::Identity::_detail::TenantIdResolver; using Azure::Identity::_detail::TokenCredentialImpl; ClientSecretCredential::ClientSecretCredential( @@ -20,8 +22,10 @@ ClientSecretCredential::ClientSecretCredential( std::string const& clientId, std::string const& clientSecret, std::string const& authorityHost, + std::vector additionallyAllowedTenants, TokenCredentialOptions const& options) - : TokenCredential("ClientSecretCredential"), m_clientCredentialCore(tenantId, authorityHost), + : TokenCredential("ClientSecretCredential"), + m_clientCredentialCore(tenantId, authorityHost, additionallyAllowedTenants), m_tokenCredentialImpl(std::make_unique(options)), m_requestBody( std::string("grant_type=client_credentials&client_id=") + Url::Encode(clientId) @@ -34,7 +38,13 @@ ClientSecretCredential::ClientSecretCredential( std::string const& clientId, std::string const& clientSecret, ClientSecretCredentialOptions const& options) - : ClientSecretCredential(tenantId, clientId, clientSecret, options.AuthorityHost, options) + : ClientSecretCredential( + tenantId, + clientId, + clientSecret, + options.AuthorityHost, + options.AdditionallyAllowedTenants, + options) { } @@ -48,6 +58,7 @@ ClientSecretCredential::ClientSecretCredential( clientId, clientSecret, ClientSecretCredentialOptions{}.AuthorityHost, + ClientSecretCredentialOptions{}.AdditionallyAllowedTenants, options) { } @@ -58,13 +69,20 @@ AccessToken ClientSecretCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const { - auto const scopesStr = m_clientCredentialCore.GetScopesString(tokenRequestContext.Scopes); + + auto const tenantId = TenantIdResolver::Resolve( + m_clientCredentialCore.GetTenantId(), + tokenRequestContext, + m_clientCredentialCore.GetAdditionallyAllowedTenants()); + + auto const scopesStr + = m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes); // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() { return m_tokenCredentialImpl->GetToken(context, [&]() { auto body = m_requestBody; @@ -73,7 +91,7 @@ AccessToken ClientSecretCredential::GetToken( body += "&scope=" + scopesStr; } - auto const requestUrl = m_clientCredentialCore.GetRequestUrl(); + auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId); auto request = std::make_unique(HttpMethod::Post, requestUrl, body); diff --git a/sdk/identity/azure-identity/src/environment_credential.cpp b/sdk/identity/azure-identity/src/environment_credential.cpp index 80dc948ff..5c1fbcf49 100644 --- a/sdk/identity/azure-identity/src/environment_credential.cpp +++ b/sdk/identity/azure-identity/src/environment_credential.cpp @@ -14,6 +14,7 @@ #include using Azure::Identity::EnvironmentCredential; +using Azure::Identity::EnvironmentCredentialOptions; using Azure::Core::Context; using Azure::Core::_internal::Environment; @@ -36,7 +37,9 @@ void PrintCredentialCreationLogMessage( char const* credThatGetsCreated); } // namespace -EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options) +EnvironmentCredential::EnvironmentCredential( + TokenCredentialOptions const& options, + std::vector const& additionallyAllowedTenants) : TokenCredential("EnvironmentCredential") { auto tenantId = Environment::GetVariable(AzureTenantIdEnvVarName); @@ -49,77 +52,48 @@ EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options) if (!tenantId.empty() && !clientId.empty()) { + std::vector> envVarsToParams + = {{AzureTenantIdEnvVarName, "tenantId"}, {AzureClientIdEnvVarName, "clientId"}}; + if (!clientSecret.empty()) { + envVarsToParams.push_back({AzureClientSecretEnvVarName, "clientSecret"}); + + ClientSecretCredentialOptions clientSecretCredentialOptions; + static_cast(clientSecretCredentialOptions) = options; + clientSecretCredentialOptions.AdditionallyAllowedTenants = additionallyAllowedTenants; + if (!authority.empty()) { - PrintCredentialCreationLogMessage( - GetCredentialName(), - { - {AzureTenantIdEnvVarName, "tenantId"}, - {AzureClientIdEnvVarName, "clientId"}, - {AzureClientSecretEnvVarName, "clientSecret"}, - {AzureAuthorityHostEnvVarName, "authorityHost"}, - }, - "ClientSecretCredential"); - - ClientSecretCredentialOptions clientSecretCredentialOptions; - static_cast(clientSecretCredentialOptions) = options; + envVarsToParams.push_back({AzureAuthorityHostEnvVarName, "authorityHost"}); clientSecretCredentialOptions.AuthorityHost = authority; - - m_credentialImpl.reset(new ClientSecretCredential( - tenantId, clientId, clientSecret, clientSecretCredentialOptions)); } - else - { - PrintCredentialCreationLogMessage( - GetCredentialName(), - { - {AzureTenantIdEnvVarName, "tenantId"}, - {AzureClientIdEnvVarName, "clientId"}, - {AzureClientSecretEnvVarName, "clientSecret"}, - }, - "ClientSecretCredential"); - m_credentialImpl.reset( - new ClientSecretCredential(tenantId, clientId, clientSecret, options)); - } + PrintCredentialCreationLogMessage( + GetCredentialName(), envVarsToParams, "ClientSecretCredential"); + + m_credentialImpl.reset(new ClientSecretCredential( + tenantId, clientId, clientSecret, clientSecretCredentialOptions)); } else if (!clientCertificatePath.empty()) { + envVarsToParams.push_back({AzureClientCertificatePathEnvVarName, "clientCertificatePath"}); + + ClientCertificateCredentialOptions clientCertificateCredentialOptions; + static_cast(clientCertificateCredentialOptions) = options; + clientCertificateCredentialOptions.AdditionallyAllowedTenants = additionallyAllowedTenants; + if (!authority.empty()) { - PrintCredentialCreationLogMessage( - GetCredentialName(), - { - {AzureTenantIdEnvVarName, "tenantId"}, - {AzureClientIdEnvVarName, "clientId"}, - {AzureClientCertificatePathEnvVarName, "clientCertificatePath"}, - {AzureAuthorityHostEnvVarName, "authorityHost"}, - }, - "ClientCertificateCredential"); - - ClientCertificateCredentialOptions clientCertificateCredentialOptions; - static_cast(clientCertificateCredentialOptions) = options; + envVarsToParams.push_back({AzureAuthorityHostEnvVarName, "authorityHost"}); clientCertificateCredentialOptions.AuthorityHost = authority; - - m_credentialImpl.reset(new ClientCertificateCredential( - tenantId, clientId, clientCertificatePath, clientCertificateCredentialOptions)); } - else - { - PrintCredentialCreationLogMessage( - GetCredentialName(), - { - {AzureTenantIdEnvVarName, "tenantId"}, - {AzureClientIdEnvVarName, "clientId"}, - {AzureClientCertificatePathEnvVarName, "clientCertificatePath"}, - }, - "ClientCertificateCredential"); - m_credentialImpl.reset( - new ClientCertificateCredential(tenantId, clientId, clientCertificatePath, options)); - } + PrintCredentialCreationLogMessage( + GetCredentialName(), envVarsToParams, "ClientCertificateCredential"); + + m_credentialImpl.reset(new ClientCertificateCredential( + tenantId, clientId, clientCertificatePath, clientCertificateCredentialOptions)); } } @@ -156,6 +130,16 @@ EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options) } } +EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions const& options) + : EnvironmentCredential(options, {}) +{ +} + +EnvironmentCredential::EnvironmentCredential(EnvironmentCredentialOptions const& options) + : EnvironmentCredential(options, options.AdditionallyAllowedTenants) +{ +} + AccessToken EnvironmentCredential::GetToken( TokenRequestContext const& tokenRequestContext, Context const& context) const diff --git a/sdk/identity/azure-identity/src/managed_identity_source.cpp b/sdk/identity/azure-identity/src/managed_identity_source.cpp index e3001b9c6..bec7309bf 100644 --- a/sdk/identity/azure-identity/src/managed_identity_source.cpp +++ b/sdk/identity/azure-identity/src/managed_identity_source.cpp @@ -136,7 +136,7 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() { return TokenCredentialImpl::GetToken(context, [&]() { auto request = std::make_unique(m_request); @@ -218,7 +218,7 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() { return TokenCredentialImpl::GetToken(context, [&]() { using Azure::Core::Url; using Azure::Core::Http::HttpMethod; @@ -317,7 +317,7 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() { return TokenCredentialImpl::GetToken( context, createRequest, @@ -417,7 +417,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( // when they are being executed. They are not supposed to keep a reference to lambda argument to // call it later. Therefore, any capture made here will outlive the possible time frame when the // lambda might get called. - return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() { + return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() { return TokenCredentialImpl::GetToken(context, [&]() { auto request = std::make_unique(m_request); diff --git a/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp b/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp new file mode 100644 index 000000000..84c590238 --- /dev/null +++ b/sdk/identity/azure-identity/src/private/tenant_id_resolver.hpp @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include +#include + +namespace Azure { namespace Identity { namespace _detail { + /** + * @brief Implements an access token cache. + * + */ + class TenantIdResolver final { + TenantIdResolver() = delete; + ~TenantIdResolver() = delete; + + public: + static std::string Resolve( + std::string const& explicitTenantId, + Core::Credentials::TokenRequestContext const& tokenRequestContext, + std::vector const& additionallyAllowedTenants); + + // ADFS is the Active Directory Federation Service, a tenant ID that is used in Azure Stack. + static bool IsAdfs(std::string const& tenantId); + }; +}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/private/token_cache.hpp b/sdk/identity/azure-identity/src/private/token_cache.hpp deleted file mode 100644 index fa88a2b12..000000000 --- a/sdk/identity/azure-identity/src/private/token_cache.hpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// SPDX-License-Identifier: MIT - -/** - * @file - * @brief Token cache. - */ - -#pragma once - -#include -#include - -#include -#include - -namespace Azure { namespace Identity { namespace _detail { - /** - * @brief Implements an access token cache. - * - */ - class TokenCache final { - TokenCache() = delete; - ~TokenCache() = delete; - - public: - /** - * @brief Attempts to get token from cache, and if not found, gets the token using the function - * provided, caches it, and returns its value. - * - * @param tenantId Azure Tenant ID. - * @param clientId Azure Client ID. - * @param authorityHost Authentication authority URL. - * @param scopes Authentication scopes. - * - * @return Authentication token. - * - * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. - */ - static Core::Credentials::AccessToken GetToken( - std::string const& tenantId, - std::string const& clientId, - std::string const& authorityHost, - std::string const& scopes, - DateTime::duration minimumExpiration, - std::function const& getNewToken); - - /** - * @brief Provides access to internal aspects of the cache as a test hook. - * - */ - class Internals; - -#if defined(TESTING_BUILD) - /** - * @brief Clears token cache. Intended to only be used in tests. - * - */ - static void Clear(); -#endif - }; -}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/tenant_id_resolver.cpp b/sdk/identity/azure-identity/src/tenant_id_resolver.cpp new file mode 100644 index 000000000..5ea882ef0 --- /dev/null +++ b/sdk/identity/azure-identity/src/tenant_id_resolver.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "private/tenant_id_resolver.hpp" + +#include +#include + +using Azure::Identity::_detail::TenantIdResolver; + +using Azure::Core::_internal::Environment; +using Azure::Core::_internal::StringExtensions; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenRequestContext; + +namespace { +bool IsMultitenantAuthDisabled() +{ + auto const envVar = Environment::GetVariable("AZURE_IDENTITY_DISABLE_MULTITENANTAUTH"); + return envVar == "1" || StringExtensions::LocaleInvariantCaseInsensitiveEqual(envVar, "true"); +} +} // namespace + +std::string TenantIdResolver::Resolve( + std::string const& explicitTenantId, + TokenRequestContext const& tokenRequestContext, + std::vector const& additionallyAllowedTenants) +{ + auto const& requestedTenantId = tokenRequestContext.TenantId; + + if (requestedTenantId.empty() + || StringExtensions::LocaleInvariantCaseInsensitiveEqual(requestedTenantId, explicitTenantId) + || IsAdfs(explicitTenantId) || IsMultitenantAuthDisabled()) + { + return explicitTenantId; + } + + for (auto const& allowedTenantId : additionallyAllowedTenants) + { + if (allowedTenantId == "*" + || StringExtensions::LocaleInvariantCaseInsensitiveEqual( + allowedTenantId, requestedTenantId)) + { + return requestedTenantId; + } + } + + throw AuthenticationException( + "The current credential is not configured to acquire tokens for tenant '" + requestedTenantId + + "'. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on " + "the credential options, or add \"*\" to AdditionallyAllowedTenants to allow acquiring " + "tokens for any tenant."); +} + +bool TenantIdResolver::IsAdfs(std::string const& tenantId) +{ + return StringExtensions::LocaleInvariantCaseInsensitiveEqual(tenantId, "adfs"); +} diff --git a/sdk/identity/azure-identity/src/token_cache.cpp b/sdk/identity/azure-identity/src/token_cache.cpp index a69f87773..a1dd5c291 100644 --- a/sdk/identity/azure-identity/src/token_cache.cpp +++ b/sdk/identity/azure-identity/src/token_cache.cpp @@ -26,7 +26,7 @@ template bool ShouldCleanUpCacheFromExpiredItems(T cacheSize); } std::shared_ptr TokenCache::GetOrCreateValue( - std::string const& key, + CacheKey const& key, DateTime::duration minimumExpiration) const { { @@ -86,10 +86,11 @@ std::shared_ptr TokenCache::GetOrCreateValue( AccessToken TokenCache::GetToken( std::string const& scopeString, + std::string const& tenantId, DateTime::duration minimumExpiration, std::function const& getNewToken) const { - auto const item = GetOrCreateValue(scopeString, minimumExpiration); + auto const item = GetOrCreateValue({scopeString, tenantId}, minimumExpiration); { std::shared_lock itemReadLock(item->ElementMutex); diff --git a/sdk/identity/azure-identity/test/ut/CMakeLists.txt b/sdk/identity/azure-identity/test/ut/CMakeLists.txt index 2c560e245..9c10e2331 100644 --- a/sdk/identity/azure-identity/test/ut/CMakeLists.txt +++ b/sdk/identity/azure-identity/test/ut/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable ( macro_guard_test.cpp managed_identity_credential_test.cpp simplified_header_test.cpp + tenant_id_resolver_test.cpp token_cache_test.cpp token_credential_impl_test.cpp token_credential_test.cpp diff --git a/sdk/identity/azure-identity/test/ut/tenant_id_resolver_test.cpp b/sdk/identity/azure-identity/test/ut/tenant_id_resolver_test.cpp new file mode 100644 index 000000000..ab0f839b5 --- /dev/null +++ b/sdk/identity/azure-identity/test/ut/tenant_id_resolver_test.cpp @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "private/tenant_id_resolver.hpp" + +#include "credential_test_helper.hpp" + +#include + +using Azure::Identity::_detail::TenantIdResolver; + +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenRequestContext; +using Azure::Identity::Test::_detail::CredentialTestHelper; + +TEST(TenantIdResolver, RequestedTenantIdEmpty) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", ""}, // Default, i.e. NOT disabled + }); + + auto const tenantId = TenantIdResolver::Resolve("aA", {}, {}); + + EXPECT_EQ(tenantId, "aA"); +} + +TEST(TenantIdResolver, RequestedTenantIdEqualsExplicitTenantId) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "0"}, // Default, i.e. NOT disabled + }); + + TokenRequestContext trc; + trc.TenantId = "Aa"; + + auto const tenantId = TenantIdResolver::Resolve("aA", trc, {}); + + EXPECT_EQ(tenantId, "aA"); +} + +TEST(TenantIdResolver, Adfs) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "false"}, // Default, i.e. NOT disabled + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + auto const tenantId = TenantIdResolver::Resolve("aDfS", trc, {}); + + EXPECT_EQ(tenantId, "aDfS"); +} + +TEST(TenantIdResolver, Disabled1) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "1"}, // Should be DISABLED + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + auto const tenantId = TenantIdResolver::Resolve("aA", trc, {}); + + EXPECT_EQ(tenantId, "aA"); +} + +TEST(TenantIdResolver, DisabledTrue) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "tRuE"}, // Should be DISABLED + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + auto const tenantId = TenantIdResolver::Resolve("aA", trc, {}); + + EXPECT_EQ(tenantId, "aA"); +} + +TEST(TenantIdResolver, Wildcard) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "2"}, // Not a value that should be recognized + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + auto const tenantId = TenantIdResolver::Resolve("aA", trc, {"cC", "*", "dD"}); + + EXPECT_EQ(tenantId, "bB"); +} + +TEST(TenantIdResolver, Match) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "T"}, // Not a value that should be recognized + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + auto const tenantId = TenantIdResolver::Resolve("bA", trc, {"cC", "Bb", "dD"}); + + EXPECT_EQ(tenantId, "bB"); +} + +TEST(TenantIdResolver, NoMatch) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "yes"}, // Not a value that should be recognized + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + EXPECT_THROW( + static_cast(TenantIdResolver::Resolve("aA", trc, {"cC", "dD"})), + AuthenticationException); +} + +TEST(TenantIdResolver, NoMatchEmpty) +{ + CredentialTestHelper::EnvironmentOverride const env(std::map{ + {"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "on"}, // Not a value that should be recognized + }); + + TokenRequestContext trc; + trc.TenantId = "bB"; + + EXPECT_THROW( + static_cast(TenantIdResolver::Resolve("aA", trc, {})), AuthenticationException); +} diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp index d3b6d01b7..4731db56c 100644 --- a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -53,7 +53,7 @@ TEST(TokenCache, GetReuseRefresh) auto const Yesterday = Tomorrow - 48h; { - auto const token1 = tokenCache.GetToken("A", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -65,7 +65,7 @@ TEST(TokenCache, GetReuseRefresh) EXPECT_EQ(token1.ExpiresOn, Tomorrow); EXPECT_EQ(token1.Token, "T1"); - auto const token2 = tokenCache.GetToken("A", 2min, [=]() { + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); AccessToken result; result.Token = "T2"; @@ -80,9 +80,9 @@ TEST(TokenCache, GetReuseRefresh) } { - tokenCache.m_cache["A"]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday; - auto const token = tokenCache.GetToken("A", 2min, [=]() { + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T3"; result.ExpiresOn = Tomorrow + 1min; @@ -106,7 +106,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) tokenCache.m_onBeforeCacheWriteLock = [&]() { tokenCache.m_onBeforeCacheWriteLock = nullptr; - static_cast(tokenCache.GetToken("A", 2min, [=]() { + static_cast(tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -114,7 +114,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) })); }; - auto const token = tokenCache.GetToken("A", 2min, [=]() { + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " "acquiring cache write lock"); AccessToken result; @@ -141,12 +141,12 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) tokenCache.m_onBeforeItemWriteLock = [&]() { tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache["A"]; + auto const item = tokenCache.m_cache[{"A", {}}]; item->AccessToken.Token = "T1"; item->AccessToken.ExpiresOn = Tomorrow; }; - auto const token = tokenCache.GetToken("A", 2min, [=]() { + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " "acquiring item write lock"); AccessToken result; @@ -167,12 +167,12 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) tokenCache.m_onBeforeItemWriteLock = [&]() { tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache["A"]; + auto const item = tokenCache.m_cache[{"A", {}}]; item->AccessToken.Token = "T3"; item->AccessToken.ExpiresOn = Yesterday; }; - auto const token = tokenCache.GetToken("A", 2min, [=]() { + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T4"; result.ExpiresOn = Tomorrow + 3min; @@ -199,7 +199,7 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 1; i <= 35; ++i) { auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -214,14 +214,14 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; } // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. for (auto i = 36; i <= 55; ++i) { auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -235,11 +235,11 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); } // One more addition to the cache and cleanup for the expired ones will get triggered. - static_cast(tokenCache.GetToken("56", 2min, [=]() { + static_cast(tokenCache.GetToken("56", {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -253,14 +253,14 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); } // Let's expire items from 21 all the way up to 56. for (auto i = 21; i <= 56; ++i) { auto const n = std::to_string(i); - tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; } // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to @@ -268,7 +268,7 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 2; i <= 3; ++i) { auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, 2min, [=]() { + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow; @@ -284,19 +284,19 @@ TEST(TokenCache, ExpiredCleanup) // Out of 4 locked, two are expired, so they should get cleared under normal circumstances, but // this time they will remain in the cache. std::shared_lock readLockForUnexpired( - tokenCache.m_cache["2"]->ElementMutex); + tokenCache.m_cache[{"2", {}}]->ElementMutex); std::shared_lock readLockForExpired( - tokenCache.m_cache["54"]->ElementMutex); + tokenCache.m_cache[{"54", {}}]->ElementMutex); std::unique_lock writeLockForUnexpired( - tokenCache.m_cache["3"]->ElementMutex); + tokenCache.m_cache[{"3", {}}]->ElementMutex); std::unique_lock writeLockForExpired( - tokenCache.m_cache["55"]->ElementMutex); + tokenCache.m_cache[{"55", {}}]->ElementMutex); // Count is at 55. Inserting the 56th element, and it will trigger cleanup. - static_cast(tokenCache.GetToken("1", 2min, [=]() { + static_cast(tokenCache.GetToken("1", {}, 2min, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow; @@ -309,17 +309,17 @@ TEST(TokenCache, ExpiredCleanup) for (auto i = 1; i <= 20; ++i) { auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); } - EXPECT_NE(tokenCache.m_cache.find("54"), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end()); - EXPECT_NE(tokenCache.m_cache.find("55"), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end()); for (auto i = 21; i <= 53; ++i) { auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); } } @@ -331,7 +331,7 @@ TEST(TokenCache, MinimumExpiration) DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const token1 = tokenCache.GetToken("A", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -343,7 +343,7 @@ TEST(TokenCache, MinimumExpiration) EXPECT_EQ(token1.ExpiresOn, Tomorrow); EXPECT_EQ(token1.Token, "T1"); - auto const token2 = tokenCache.GetToken("A", 24h, [=]() { + auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() { AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow + 1h; @@ -364,7 +364,7 @@ TEST(TokenCache, MultithreadedAccess) DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const token1 = tokenCache.GetToken("A", 2min, [=]() { + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; @@ -377,14 +377,15 @@ TEST(TokenCache, MultithreadedAccess) EXPECT_EQ(token1.Token, "T1"); { - std::shared_lock itemReadLock(tokenCache.m_cache["A"]->ElementMutex); + std::shared_lock itemReadLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); { std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); - // Parallel threads read both the container and the item we're accessing, and we can access it - // in parallel as well. - auto const token2 = tokenCache.GetToken("A", 2min, [=]() { + // Parallel threads read both the container and the item we're accessing, and we can + // access it in parallel as well. + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); AccessToken result; result.Token = "T2"; @@ -400,7 +401,7 @@ TEST(TokenCache, MultithreadedAccess) // The cache is unlocked, but one item is being read in a parallel thread, which does not // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("B", 2min, [=]() { + auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() { AccessToken result; result.Token = "T3"; result.ExpiresOn = Tomorrow + 2h; @@ -414,11 +415,12 @@ TEST(TokenCache, MultithreadedAccess) } { - std::unique_lock itemWriteLock(tokenCache.m_cache["A"]->ElementMutex); + std::unique_lock itemWriteLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); // The cache is unlocked, but one item is being written in a parallel thread, which does not // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("C", 2min, [=]() { + auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() { AccessToken result; result.Token = "T4"; result.ExpiresOn = Tomorrow + 3h; @@ -514,7 +516,8 @@ TEST(TokenCache, PerCredInstance) auto const tokenB = credB.GetToken(getCached, {}); EXPECT_EQ( tokenB.Token, - "SecretB2"); // if token cache was shared between instances, the value would be "SecretA1" + "SecretB2"); // if token cache was shared between instances, the value would be + // "SecretA1" } { @@ -541,3 +544,128 @@ TEST(TokenCache, PerCredInstance) EXPECT_EQ(tokenA6.Token, "SecretA4"); } } + +TEST(TokenCache, TenantId) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + AccessToken result; + result.Token = "AX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + AccessToken result; + result.Token = "BX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + AccessToken result; + result.Token = "AY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + AccessToken result; + result.Token = "BY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } +} diff --git a/sdk/identity/azure-identity/vcpkg/vcpkg.json b/sdk/identity/azure-identity/vcpkg/vcpkg.json index dc380c61a..6598bd81e 100644 --- a/sdk/identity/azure-identity/vcpkg/vcpkg.json +++ b/sdk/identity/azure-identity/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.8.0" + "version>=": "1.9.0-beta.1" }, "openssl", { diff --git a/sdk/keyvault/azure-security-keyvault-administration/CHANGELOG.md b/sdk/keyvault/azure-security-keyvault-administration/CHANGELOG.md index 060b91294..79b49e0fb 100644 --- a/sdk/keyvault/azure-security-keyvault-administration/CHANGELOG.md +++ b/sdk/keyvault/azure-security-keyvault-administration/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added support for challenge-based and multi-tenant authentication. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/keyvault/azure-security-keyvault-administration/src/settings_client.cpp b/sdk/keyvault/azure-security-keyvault-administration/src/settings_client.cpp index 16e3ec450..de106ec4e 100644 --- a/sdk/keyvault/azure-security-keyvault-administration/src/settings_client.cpp +++ b/sdk/keyvault/azure-security-keyvault-administration/src/settings_client.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -51,7 +52,8 @@ SettingsClient::SettingsClient( tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)}; perRetrypolicies.emplace_back( - std::make_unique(credential, std::move(tokenContext))); + std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>( + credential, std::move(tokenContext))); } std::vector> perCallpolicies; diff --git a/sdk/keyvault/azure-security-keyvault-administration/vcpkg/vcpkg.json b/sdk/keyvault/azure-security-keyvault-administration/vcpkg/vcpkg.json index 38f4d1e51..543718f89 100644 --- a/sdk/keyvault/azure-security-keyvault-administration/vcpkg/vcpkg.json +++ b/sdk/keyvault/azure-security-keyvault-administration/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.7.2" + "version>=": "1.9.0-beta.1" }, { "name": "vcpkg-cmake", diff --git a/sdk/keyvault/azure-security-keyvault-certificates/CHANGELOG.md b/sdk/keyvault/azure-security-keyvault-certificates/CHANGELOG.md index db741d299..586ad8ced 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/CHANGELOG.md +++ b/sdk/keyvault/azure-security-keyvault-certificates/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added support for challenge-based and multi-tenant authentication. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/keyvault/azure-security-keyvault-certificates/src/certificate_client.cpp b/sdk/keyvault/azure-security-keyvault-certificates/src/certificate_client.cpp index 85aa863db..58ad90b02 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/src/certificate_client.cpp +++ b/sdk/keyvault/azure-security-keyvault-certificates/src/certificate_client.cpp @@ -3,6 +3,7 @@ #include "azure/keyvault/certificates/certificate_client.hpp" +#include "azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp" #include "azure/keyvault/shared/keyvault_shared.hpp" #include "private/certificate_constants.hpp" #include "private/certificate_serializers.hpp" @@ -76,7 +77,8 @@ CertificateClient::CertificateClient( tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)}; perRetrypolicies.emplace_back( - std::make_unique(credential, std::move(tokenContext))); + std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>( + credential, std::move(tokenContext))); } std::vector> perCallpolicies; diff --git a/sdk/keyvault/azure-security-keyvault-certificates/vcpkg/vcpkg.json b/sdk/keyvault/azure-security-keyvault-certificates/vcpkg/vcpkg.json index 73dc87af8..f2a641e49 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/vcpkg/vcpkg.json +++ b/sdk/keyvault/azure-security-keyvault-certificates/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.5.0" + "version>=": "1.9.0-beta.1" }, { "name": "vcpkg-cmake", diff --git a/sdk/keyvault/azure-security-keyvault-keys/CHANGELOG.md b/sdk/keyvault/azure-security-keyvault-keys/CHANGELOG.md index 7c2be54c1..cf55e1727 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/CHANGELOG.md +++ b/sdk/keyvault/azure-security-keyvault-keys/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added support for challenge-based and multi-tenant authentication. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/keyvault/azure-security-keyvault-keys/src/cryptography/cryptography_client.cpp b/sdk/keyvault/azure-security-keyvault-keys/src/cryptography/cryptography_client.cpp index 65cef9a6b..1df12afca 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/src/cryptography/cryptography_client.cpp +++ b/sdk/keyvault/azure-security-keyvault-keys/src/cryptography/cryptography_client.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include "azure/keyvault/keys/cryptography/cryptography_client.hpp" @@ -105,7 +106,8 @@ CryptographyClient::CryptographyClient( tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_keyId)}; perRetrypolicies.emplace_back( - std::make_unique(credential, tokenContext)); + std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>( + credential, tokenContext)); } std::vector> perCallpolicies; diff --git a/sdk/keyvault/azure-security-keyvault-keys/src/key_client.cpp b/sdk/keyvault/azure-security-keyvault-keys/src/key_client.cpp index f1af0dc03..0dae980d1 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/src/key_client.cpp +++ b/sdk/keyvault/azure-security-keyvault-keys/src/key_client.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "azure/keyvault/keys/key_client.hpp" @@ -75,7 +76,8 @@ KeyClient::KeyClient( tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)}; perRetrypolicies.emplace_back( - std::make_unique(credential, std::move(tokenContext))); + std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>( + credential, std::move(tokenContext))); } std::vector> perCallpolicies; diff --git a/sdk/keyvault/azure-security-keyvault-keys/vcpkg/vcpkg.json b/sdk/keyvault/azure-security-keyvault-keys/vcpkg/vcpkg.json index ef96e5e75..de131d2d7 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/vcpkg/vcpkg.json +++ b/sdk/keyvault/azure-security-keyvault-keys/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.5.0" + "version>=": "1.9.0-beta.1" }, { "name": "vcpkg-cmake", diff --git a/sdk/keyvault/azure-security-keyvault-secrets/CHANGELOG.md b/sdk/keyvault/azure-security-keyvault-secrets/CHANGELOG.md index 67cabd93a..d445db0d7 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/CHANGELOG.md +++ b/sdk/keyvault/azure-security-keyvault-secrets/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added support for challenge-based and multi-tenant authentication. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/keyvault/azure-security-keyvault-secrets/src/secret_client.cpp b/sdk/keyvault/azure-security-keyvault-secrets/src/secret_client.cpp index 8379cc940..24f485c9d 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/src/secret_client.cpp +++ b/sdk/keyvault/azure-security-keyvault-secrets/src/secret_client.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -72,7 +73,8 @@ SecretClient::SecretClient( tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(url)}; perRetrypolicies.emplace_back( - std::make_unique(credential, tokenContext)); + std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>( + credential, tokenContext)); } std::vector> perCallpolicies; diff --git a/sdk/keyvault/azure-security-keyvault-secrets/test/ut/CMakeLists.txt b/sdk/keyvault/azure-security-keyvault-secrets/test/ut/CMakeLists.txt index 833342f8b..ad36c53f4 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/test/ut/CMakeLists.txt +++ b/sdk/keyvault/azure-security-keyvault-secrets/test/ut/CMakeLists.txt @@ -16,6 +16,7 @@ SetUpTestProxy("sdk/keyvault") add_executable ( azure-security-keyvault-secrets-test + challenge_based_authentication_policy_test.cpp macro_guard.cpp secret_client_test.cpp secret_get_client_deserialize_test.hpp @@ -38,7 +39,12 @@ endif() target_link_libraries(azure-security-keyvault-secrets-test PRIVATE azure-security-keyvault-secrets azure-identity azure-core-test-fw gtest gtest_main gmock) # Adding private headers so we can test the private APIs with no relative paths include. -target_include_directories (azure-security-keyvault-secrets-test PRIVATE $) +target_include_directories( + azure-security-keyvault-secrets-test + PRIVATE + $ + $ +) # gtest_add_tests will scan the test from azure-core-test and call add_test # for each test to ctest. This enables `ctest -r` to run specific tests directly. diff --git a/sdk/keyvault/azure-security-keyvault-secrets/test/ut/challenge_based_authentication_policy_test.cpp b/sdk/keyvault/azure-security-keyvault-secrets/test/ut/challenge_based_authentication_policy_test.cpp new file mode 100644 index 000000000..80f1dbfa2 --- /dev/null +++ b/sdk/keyvault/azure-security-keyvault-secrets/test/ut/challenge_based_authentication_policy_test.cpp @@ -0,0 +1,1439 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp" + +#include "azure/keyvault/shared/keyvault_shared.hpp" +#include +#include + +#include +#include + +#include + +// cspell:ignore Fvault Ftest + +using Azure::Security::KeyVault::_internal::KeyVaultChallengeBasedAuthenticationPolicy; + +using Azure::Core::CaseInsensitiveMap; +using Azure::Core::Context; +using Azure::Core::Url; +using Azure::Core::_internal::ClientOptions; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredential; +using Azure::Core::Http::HttpStatusCode; +using Azure::Core::Http::HttpTransport; +using Azure::Core::Http::RawResponse; +using Azure::Core::Http::Request; +using Azure::Core::Http::_internal::HttpPipeline; +using Azure::Identity::ClientSecretCredentialOptions; + +namespace { +class TestRequest final { +public: + Azure::Core::Url Url; + CaseInsensitiveMap Headers; + std::string Body; + + explicit TestRequest(Request& request) : Url(request.GetUrl()), Headers(request.GetHeaders()) + { + auto const bodyStreamPtr = request.GetBodyStream(); + if (bodyStreamPtr != nullptr) + { + auto const uint8Vector = bodyStreamPtr->ReadToEnd(); + auto const charPtr = reinterpret_cast(uint8Vector.data()); + Body = std::string(charPtr, charPtr + uint8Vector.size()); + } + } +}; + +class TestResponse final { +private: + HttpStatusCode m_statusCode; + std::shared_ptr m_body; + CaseInsensitiveMap const m_headers; + +public: + TestResponse(HttpStatusCode statusCode, std::string body, CaseInsensitiveMap headers) + : m_statusCode(std::move(statusCode)), m_body(std::make_shared(std::move(body))), + m_headers(std::move(headers)) + { + } + + std::unique_ptr CreateRawResponse() const + { + using Azure::Core::IO::MemoryBodyStream; + + auto response = std::make_unique(1, 1, m_statusCode, "TestReasonPhrase"); + + for (auto const& header : m_headers) + { + response->SetHeader(header.first, header.second); + } + + auto const bodyPtr = reinterpret_cast(m_body->data()); + response->SetBodyStream(std::make_unique(bodyPtr, m_body->size())); + + return response; + } +}; + +class TestHttpTransport final : public HttpTransport { +private: + std::shared_ptr> m_requests; + std::vector m_responses; + decltype(m_responses)::size_type m_currentResponse; + +public: + explicit TestHttpTransport( + std::shared_ptr> requests, + std::vector responses) + : m_requests(std::move(requests)), m_responses(std::move(responses)), m_currentResponse(0) + { + } + + std::unique_ptr Send(Request& request, Context const&) override + { + EXPECT_LT(m_currentResponse, m_responses.size()); + + m_requests->emplace_back(TestRequest(request)); + return m_responses.at(m_currentResponse++).CreateRawResponse(); + } +}; + +class TestKeyVaultClient final { + std::shared_ptr m_pipeline; + Url m_vaultUrl; + +public: + explicit TestKeyVaultClient( + std::string vaultUrl, + std::shared_ptr credential, + std::shared_ptr testHttpTransport) + : m_vaultUrl(vaultUrl) + { + using Azure::Core::Http::Policies::HttpPolicy; + using Azure::Security::KeyVault::_internal::UrlScope; + + ClientOptions options; + options.Transport.Transport = testHttpTransport; + + Azure::Core::Credentials::TokenRequestContext tokenContext; + tokenContext.Scopes = {UrlScope::GetScopeFromUrl(m_vaultUrl)}; + + std::vector> perRetryPolicies; + perRetryPolicies.emplace_back( + std::make_unique(credential, tokenContext)); + + std::vector> perCallPolicies; + + m_pipeline = std::make_shared( + options, + "TestKeyVaultClient", + "1.0.0", + std::move(perRetryPolicies), + std::move(perCallPolicies)); + } + + std::unique_ptr DoSomething(Context const& context = {}) const + { + using Azure::Core::Http::HttpMethod; + auto request = Request(HttpMethod::Get, m_vaultUrl); + return m_pipeline->Send(request, context); + } +}; + +std::shared_ptr CreateTestCredential( + std::shared_ptr testHttpTransport, + decltype(ClientSecretCredentialOptions::AdditionallyAllowedTenants) additionallyAllowedTenants + = {}) +{ + using Azure::Identity::ClientSecretCredential; + + ClientSecretCredentialOptions options; + options.Transport.Transport = testHttpTransport; + options.AdditionallyAllowedTenants = additionallyAllowedTenants; + + return std::make_shared( + "OriginalTenantId", "ClientId", "ClientSecret", options); +} + +std::string GetTenantIdFromClientSecretRequest(TestRequest const& request) +{ + auto const urlPath = request.Url.GetPath(); + auto const slashPos = urlPath.find('/'); + return (slashPos != std::string::npos) ? urlPath.substr(0, slashPos) : urlPath; +} + +std::string GetScopeFromClientSecretRequest(TestRequest const& request) +{ + std::string const ScopeParam = "scope="; + auto const scopeParamStart = request.Body.find(ScopeParam); + if (scopeParamStart == std::string::npos) + { + return {}; + } + + auto const scopeValueStart = scopeParamStart + ScopeParam.length(); + auto const nextParamStart = request.Body.find('&', scopeValueStart); + + auto const scopeValueEnd + = (nextParamStart != std::string::npos) ? nextParamStart : request.Body.length(); + + return request.Body.substr(scopeValueStart, scopeValueEnd); +} + +std::string GetAuthHeaderValueFromServiceRequest(TestRequest const& request) +{ + auto const authHeaderIter = request.Headers.find("authorization"); + return (authHeaderIter != request.Headers.end()) ? authHeaderIter->second : std::string{}; +} +} // namespace + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, BearerTokenAuthPolicyCompatible) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherScopeAsScope) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net/OriginalTenantId\"," + " scope=\"https://test.vault.azure.net/.default\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherScopeAsResource) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net/OriginalTenantId\"," + " resource=\"https://test.vault.azure.net\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherTenantAsterisk) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + }), + {"*"}), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net/NewTenantId\"," + " resource=\"https://vault.azure.net\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "NewTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherTenantAndScopeWithAltNames) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + }), + {"*"}), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization_uri=\"https://login.windows.net/NewTenantId/\"," + " scope=\"https://test.vault.azure.net/.default\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "NewTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherTenantExplicit) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + }), + {"NewTenantId"}), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net/NewTenantId\"," + " resource=\"https://vault.azure.net\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "NewTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AnotherTenantNotAllowed) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + }), + {"UnknownTenantId"}), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net/NewTenantId\"," + " resource=\"https://vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, MissingScope) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + auto const serviceResponse + = TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization=\"https://login.windows.net/OriginalTenantId\""), + }), + })) + .DoSomething(); + + EXPECT_TRUE(serviceResponse); + { + EXPECT_EQ(serviceResponse->GetStatusCode(), HttpStatusCode::Unauthorized); + + auto const& responseHeaders = serviceResponse->GetHeaders(); + auto const authHeader = responseHeaders.find("WWW-Authenticate"); + EXPECT_NE(authHeader, responseHeaders.end()); + EXPECT_EQ( + authHeader->second, "Bearer authorization=\"https://login.windows.net/OriginalTenantId\""); + } + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, EmptyScope) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + auto const serviceResponse + = TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization=\"https://login.windows.net/OriginalTenantId\"," + " scope=\"\""), + }), + })) + .DoSomething(); + + EXPECT_TRUE(serviceResponse); + { + EXPECT_EQ(serviceResponse->GetStatusCode(), HttpStatusCode::Unauthorized); + + auto const& responseHeaders = serviceResponse->GetHeaders(); + auto const authHeader = responseHeaders.find("WWW-Authenticate"); + EXPECT_NE(authHeader, responseHeaders.end()); + EXPECT_EQ( + authHeader->second, + "Bearer authorization=\"https://login.windows.net/OriginalTenantId\", scope=\"\""); + } + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, ScopeValidationInvalidUrl) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer " + "authorization=\"https://login.windows.net/OriginalTenantId\"," + " resource=\"nonparseable_url\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, ScopeValidationLongerDomain) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer " + "authorization=\"https://login.windows.net/OriginalTenantId\"," + " resource=\"longer.test.vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, ScopeValidationDomainMismatch) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer " + "authorization=\"https://login.windows.net/OriginalTenantId\"," + " resource=\"vault.azure.com\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AuthorizationMissing) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", "Bearer resource=\"vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AuthorizationEmpty) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"\", resource=\"vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AuthorizationInvalidUrl) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"nonparseable_url\"," + " resource=\"vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AuthorizationEmptyPath) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + EXPECT_THROW( + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential(std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN\"}", + {}), + })), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer authorization=\"https://login.windows.net\"," + " resource=\"vault.azure.net\""), + }), + })) + .DoSomething()), + AuthenticationException); + + EXPECT_EQ(identityRequests->size(), 1); + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + EXPECT_EQ(serviceRequests->size(), 1); + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN"); + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, AuthorizationLongerPath) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + static_cast( // + TestKeyVaultClient( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + }), + {"*"}), + std::make_shared( + serviceRequests, + std::vector{ + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization=\"https://login.windows.net/NewTenantId/whatever\"," + " scope=\"https://test.vault.azure.net/.default\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), + })) + .DoSomething()); + + EXPECT_EQ(identityRequests->size(), 2); + { + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "NewTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + } + + EXPECT_EQ(serviceRequests->size(), 2); + { + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN2"); + } + } +} + +TEST(KeyVaultChallengeBasedAuthenticationPolicy, MultipleTimes) +{ + auto identityRequests = std::make_shared>(); + auto serviceRequests = std::make_shared>(); + + TestKeyVaultClient client( + "https://test.vault.azure.net", + CreateTestCredential( + std::make_shared( + identityRequests, + std::vector{ + TestResponse( // <-- DoSomething() #1 + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN1\"}", + {}), + TestResponse( // <-- DoSomething() #2 + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN2\"}", + {}), + TestResponse( // <-- DoSomething() #4 + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN3\"}", + {}), + TestResponse( // <-- DoSomething() #7 + HttpStatusCode::Ok, + "{\"expires_in\":3600,\"access_token\":\"ACCESSTOKEN4\"}", + {}), + }), + {"*"}), + std::make_shared( + serviceRequests, + std::vector{ + // DoSomething() #1 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), // OriginalTenantId, TOKEN1 + + // DoSomething() #2 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization=\"https://login.windows.net/NewTenantId/whatever\"," + " scope=\"https://test.vault.azure.net/.default\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), // NewTenantId, TOKEN2 + + // DoSomething() #3 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), + + // DoSomething() #4 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization_uri=\"https://login.windows.net/AnotherTenantId\"," + " resource=\"https://test.vault.azure.net/\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), // AnotherTenantId (test.vault...), TOKEN3 + + // DoSomething() #5 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), + + // DoSomething() #6 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), + + // DoSomething() #7 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization_uri=\"https://login.windows.net/AnotherTenantId\"," + " resource=\"https://vault.azure.net\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), // AnotherTenantId (vault.azure...), TOKEN4 + + // DoSomething() #8 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", "Bearer resource=\"https://vault.azure.net\""), + }), // ^^^ authorization_uri is missing - throws + + // DoSomething() #9 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), // AnotherTenantId (vault.azure...), TOKEN4 + + // DoSomething() #10 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization_uri=\"https://login.windows.net/OriginalTenantId\"," + " resource=\"https://vault.azure.net/\""), + }), + TestResponse(HttpStatusCode::Ok, {}, {}), // OriginalTenantId, cached TOKEN2 + + // DoSomething() #11 vvvvv + TestResponse( + HttpStatusCode::Unauthorized, + {}, + { + std::make_pair( + "WWW-Authenticate", + "Bearer" + " authorization=\"https://login.windows.net/NewTenantId\""), + }), // ^^^ resource is missing - won't update token + + // DoSomething() #12 vvvvv + TestResponse(HttpStatusCode::Ok, {}, {}), // OriginalTenantId, TOKEN5 + })); + + client.DoSomething(); // #1: Ok with defaults, authorize with TOKEN1 + client.DoSomething(); // #2: Challenge response, NewTenantId, new scope, authorize with TOKEN2 + client.DoSomething(); // #3: Ok, authorize with TOKEN2 + client.DoSomething(); // #4: Challenge response, AnotherTenantId, same scope, auth with TOKEN3 + client.DoSomething(); // #5: Ok, authorize with TOKEN3 + client.DoSomething(); // #6: Ok, authorize with TOKEN3 + client.DoSomething(); // #7: Challenge response, same TenantId, new scope, authorize with TOKEN4 + EXPECT_THROW(client.DoSomething(), AuthenticationException); // #8: Bad challenge (no TenantId) + client.DoSomething(); // #9: Ok, keeps authorizing with TOKEN4 + client.DoSomething(); // #10: Revert back to OriginalTokenId, use cached TOKEN1 + client.DoSomething(); // #11: Attempt NewTenantId, but scope is missing + client.DoSomething(); // #12: Ok, authorize with TOKEN1 + + EXPECT_EQ(identityRequests->size(), 4); + { + // DoSomething() #1 vvv + { + auto const& identityRequest0 = identityRequests->at(0); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest0), "OriginalTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest0), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + // DoSomething() #2 vvv + { + auto const& identityRequest1 = identityRequests->at(1); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest1), "NewTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest1), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + + // DoSomething() #4 vvv + { + auto const& identityRequest2 = identityRequests->at(2); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest2), "AnotherTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest2), + "https%3A%2F%2Ftest.vault.azure.net%2F.default"); + } + + // DoSomething() #7 vvv + { + auto const& identityRequest3 = identityRequests->at(3); + EXPECT_EQ(GetTenantIdFromClientSecretRequest(identityRequest3), "AnotherTenantId"); + EXPECT_EQ( + GetScopeFromClientSecretRequest(identityRequest3), + "https%3A%2F%2Fvault.azure.net%2F.default"); + } + + // DoSomething() #10 won't make a request because the token is cached + } + + EXPECT_EQ(serviceRequests->size(), 16); + { + // DoSomething() #1 vvv + { + auto const& serviceRequest0 = serviceRequests->at(0); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest0), "Bearer ACCESSTOKEN1"); + } + + // DoSomething() #2 vvv + { + auto const& serviceRequest1 = serviceRequests->at(1); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest1), "Bearer ACCESSTOKEN1"); + } + + { + auto const& serviceRequest2 = serviceRequests->at(2); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest2), "Bearer ACCESSTOKEN2"); + } + + // DoSomething() #3 vvv + { + auto const& serviceRequest3 = serviceRequests->at(3); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest3), "Bearer ACCESSTOKEN2"); + } + + // DoSomething() #4 vvv + { + auto const& serviceRequest4 = serviceRequests->at(4); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest4), "Bearer ACCESSTOKEN2"); + } + + { + auto const& serviceRequest5 = serviceRequests->at(5); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest5), "Bearer ACCESSTOKEN3"); + } + + // DoSomething() #5 vvv + { + auto const& serviceRequest6 = serviceRequests->at(6); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest6), "Bearer ACCESSTOKEN3"); + } + + // DoSomething() #6 vvv + { + auto const& serviceRequest7 = serviceRequests->at(7); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest7), "Bearer ACCESSTOKEN3"); + } + + // DoSomething() #7 vvv + { + auto const& serviceRequest8 = serviceRequests->at(8); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest8), "Bearer ACCESSTOKEN3"); + } + + { + auto const& serviceRequest9 = serviceRequests->at(9); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest9), "Bearer ACCESSTOKEN4"); + } + + // DoSomething() #8 vvv + { + auto const& serviceRequest10 = serviceRequests->at(10); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest10), "Bearer ACCESSTOKEN4"); + } + + // DoSomething() #9 vvv + { + auto const& serviceRequest11 = serviceRequests->at(11); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest11), "Bearer ACCESSTOKEN4"); + } + + // DoSomething() #10 vvv + { + auto const& serviceRequest12 = serviceRequests->at(12); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest12), "Bearer ACCESSTOKEN4"); + } + + { + auto const& serviceRequest13 = serviceRequests->at(13); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest13), "Bearer ACCESSTOKEN1"); + } + + // DoSomething() #11 vvv + { + auto const& serviceRequest14 = serviceRequests->at(14); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest14), "Bearer ACCESSTOKEN1"); + } + + // DoSomething() #12 vvv + { + auto const& serviceRequest15 = serviceRequests->at(15); + EXPECT_EQ(GetAuthHeaderValueFromServiceRequest(serviceRequest15), "Bearer ACCESSTOKEN1"); + } + } +} diff --git a/sdk/keyvault/azure-security-keyvault-secrets/vcpkg/vcpkg.json b/sdk/keyvault/azure-security-keyvault-secrets/vcpkg/vcpkg.json index 07c978c24..4f192a722 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/vcpkg/vcpkg.json +++ b/sdk/keyvault/azure-security-keyvault-secrets/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.5.0" + "version>=": "1.9.0-beta.1" }, { "name": "vcpkg-cmake", diff --git a/sdk/keyvault/azure-security-keyvault-shared/inc/azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp b/sdk/keyvault/azure-security-keyvault-shared/inc/azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp new file mode 100644 index 000000000..b843116a3 --- /dev/null +++ b/sdk/keyvault/azure-security-keyvault-shared/inc/azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Key Vault Challenge-Based Authentication Policy. + * + */ + +#pragma once + +#include +#include + +#include + +namespace Azure { namespace Security { namespace KeyVault { namespace _internal { + /** + * @brief Challenge-Based Authentication Policy for Key Vault. + * + */ + class KeyVaultChallengeBasedAuthenticationPolicy final + : public Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy { + private: + mutable Core::Credentials::TokenRequestContext m_tokenRequestContext; + + public: + explicit KeyVaultChallengeBasedAuthenticationPolicy( + std::shared_ptr credential, + Core::Credentials::TokenRequestContext tokenRequestContext) + : BearerTokenAuthenticationPolicy(credential, tokenRequestContext), + m_tokenRequestContext(tokenRequestContext) + { + } + + std::unique_ptr Clone() const override + { + return std::make_unique(*this); + } + + private: + std::unique_ptr AuthorizeAndSendRequest( + Core::Http::Request& request, + Core::Http::Policies::NextHttpPolicy& nextPolicy, + Core::Context const& context) const override + { + AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context); + return nextPolicy.Send(request, context); + } + + bool AuthorizeRequestOnChallenge( + std::string const& challenge, + Core::Http::Request& request, + Core::Context const& context) const override + { + auto const scope = GetScope(challenge); + if (scope.empty()) + { + return false; + } + + ValidateChallengeResponse(scope, request.GetUrl().GetHost()); + + auto const tenantId = GetTenantId(GetAuthorization(challenge)); + m_tokenRequestContext.TenantId = tenantId; + m_tokenRequestContext.Scopes = {scope}; + + AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context); + return true; + } + + static std::string TrimTrailingSlash(std::string const& s) + { + return (s.empty() || s.back() != '/') ? s : s.substr(0, s.size() - 1); + } + + static std::string GetScope(std::string const& challenge) + { + using Core::Credentials::_internal::AuthorizationChallengeParser; + + auto resource + = AuthorizationChallengeParser::GetChallengeParameter(challenge, "Bearer", "resource"); + + return !resource.empty() + ? (TrimTrailingSlash(resource) + "/.default") + : AuthorizationChallengeParser::GetChallengeParameter(challenge, "Bearer", "scope"); + } + + static std::string GetAuthorization(std::string const& challenge) + { + using Core::Credentials::_internal::AuthorizationChallengeParser; + + auto authorization = AuthorizationChallengeParser::GetChallengeParameter( + challenge, "Bearer", "authorization"); + + return !authorization.empty() ? authorization + : AuthorizationChallengeParser::GetChallengeParameter( + challenge, "Bearer", "authorization_uri"); + } + + static bool TryParseUrl(std::string const& s, Core::Url& outUrl) + { + using Core::Url; + try + { + outUrl = Url(s); + } + catch (std::out_of_range const&) + { + return false; + } + catch (std::invalid_argument const&) + { + return false; + } + + return true; + } + + static void ValidateChallengeResponse(std::string const& scope, std::string const& requestHost) + { + using Core::Url; + using Core::Credentials::AuthenticationException; + + Url scopeUrl; + if (!TryParseUrl(scope, scopeUrl)) + { + throw AuthenticationException("The challenge contains invalid scope '" + scope + "'."); + } + + auto const& scopeHost = scopeUrl.GetHost(); + + // Check whether requestHost.ends_with(scopeHost) + auto const requestHostLength = requestHost.length(); + auto const scopeHostLength = scopeHost.length(); + + bool domainMismatch = requestHostLength < scopeHostLength; + if (!domainMismatch) + { + auto const requestHostOffset = requestHostLength - scopeHostLength; + for (size_t i = 0; i < scopeHostLength; ++i) + { + if (requestHost[requestHostOffset + i] != scopeHost[i]) + { + domainMismatch = true; + break; + } + } + } + + if (domainMismatch) + { + throw AuthenticationException( + "The challenge resource '" + scopeHost + "' does not match the requested domain."); + } + } + + static std::string GetTenantId(std::string const& authorization) + { + using Core::Url; + using Core::Credentials::AuthenticationException; + + if (!authorization.empty()) + { + Url authorizationUrl; + if (TryParseUrl(authorization, authorizationUrl)) + { + auto const& path = authorizationUrl.GetPath(); + if (!path.empty()) + { + auto const firstSlash = path.find('/'); + if (firstSlash == std::string::npos) + { + return path; + } + else if (firstSlash > 0) + { + return path.substr(0, firstSlash); + } + } + } + } + + throw AuthenticationException( + "The challenge authorization URI '" + authorization + "' is invalid."); + } + }; +}}}} // namespace Azure::Security::KeyVault::_internal