Add support for challenge-based and multi-tenant authentication (#4506)

* Add support for challenge-based and multi-tenant authentication

* Clang-format

* cspell

* clang-format

* gcc warning

* clang warning

* Remove debug message

* clang-format

* update version>= in vcpkg manifests

* unpublic copy ctor in polymorphic class

* KeyVault::_internal::ChallengeBasedAuthenticationPolicy => KeyVault::_internal::KeyVaultChallengeBasedAuthenticationPolicy

* keyvault/shared/challenge_based_authentication_policy.hpp => keyvault/shared/keyvault_challenge_based_authentication_policy.hpp

---------

Co-authored-by: Anton Kolesnyk <antkmsft@users.noreply.github.com>
This commit is contained in:
Anton Kolesnyk 2023-04-04 11:00:05 -07:00 committed by GitHub
parent 02bb09aac1
commit fd687c32fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 3194 additions and 331 deletions

View File

@ -7,6 +7,7 @@
- Added the ability to ignore invalid certificate common name for TLS connections in WinHTTP transport.
- Added `DisableTlsCertificateValidation` in `TransportOptions`.
- Added `TokenCredential::GetCredentialName()` to be utilized in diagnostic messages. If you have any custom implementations of `TokenCredential`, it is recommended to pass the name of your credential to `TokenCredential` constructor. The old parameterless constructor is deprecated.
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes

View File

@ -83,6 +83,7 @@ set(
inc/azure/core/http/policies/policy.hpp
inc/azure/core/http/raw_response.hpp
inc/azure/core/http/transport.hpp
inc/azure/core/internal/credentials/authorization_challenge_parser.hpp
inc/azure/core/internal/client_options.hpp
inc/azure/core/internal/contract.hpp
inc/azure/core/internal/cryptography/sha_hash.hpp
@ -121,6 +122,7 @@ set(
src/azure_assert.cpp
src/base64.cpp
src/context.cpp
src/credentials/authorization_challenge_parser.cpp
src/cryptography/md5.cpp
src/cryptography/sha_hash.cpp
src/datetime.cpp

View File

@ -54,6 +54,12 @@ namespace Azure { namespace Core { namespace Credentials {
*
*/
DateTime::duration MinimumExpiration = std::chrono::minutes(2);
/**
* @brief Tenant ID.
*
*/
std::string TenantId;
};
/**

View File

@ -543,16 +543,14 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
* @brief Bearer Token authentication policy.
*
*/
class BearerTokenAuthenticationPolicy final : public HttpPolicy {
class BearerTokenAuthenticationPolicy : public HttpPolicy {
private:
std::shared_ptr<Credentials::TokenCredential const> const m_credential;
Credentials::TokenRequestContext m_tokenRequestContext;
mutable Credentials::AccessToken m_accessToken;
mutable std::mutex m_accessTokenMutex;
BearerTokenAuthenticationPolicy(BearerTokenAuthenticationPolicy const&) = delete;
void operator=(BearerTokenAuthenticationPolicy const&) = delete;
mutable Credentials::TokenRequestContext m_accessTokenContext;
public:
/**
@ -571,14 +569,36 @@ namespace Azure { namespace Core { namespace Http { namespace Policies {
std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<BearerTokenAuthenticationPolicy>(
m_credential, m_tokenRequestContext);
return std::unique_ptr<HttpPolicy>(new BearerTokenAuthenticationPolicy(*this));
}
std::unique_ptr<RawResponse> Send(
Request& request,
NextHttpPolicy nextPolicy,
Context const& context) const override;
protected:
BearerTokenAuthenticationPolicy(BearerTokenAuthenticationPolicy const& other)
: BearerTokenAuthenticationPolicy(other.m_credential, other.m_tokenRequestContext)
{
}
void operator=(BearerTokenAuthenticationPolicy const&) = delete;
virtual std::unique_ptr<RawResponse> AuthorizeAndSendRequest(
Request& request,
NextHttpPolicy& nextPolicy,
Context const& context) const;
virtual bool AuthorizeRequestOnChallenge(
std::string const& challenge,
Request& request,
Context const& context) const;
void AuthenticateAndAuthorizeRequest(
Request& request,
Credentials::TokenRequestContext const& tokenRequestContext,
Context const& context) const;
};
/**

View File

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Parser for challenge-based authentication policy.
*/
#pragma once
#include "azure/core/http/raw_response.hpp"
#include <string>
namespace Azure { namespace Core { namespace Credentials {
namespace _detail {
class AuthorizationChallengeHelper final {
public:
static std::string const& GetChallenge(Http::RawResponse const& response);
};
} // namespace _detail
namespace _internal {
class AuthorizationChallengeParser final {
private:
AuthorizationChallengeParser() = delete;
~AuthorizationChallengeParser() = delete;
public:
/**
* @brief Gets challenge parameter from authentication challenge.
*
* @param challenge Authentication challenge.
* @param challengeScheme The challenge scheme containing the \p challengeParameter.
* @param challengeParameter The parameter key name containing the value to return.
*
* @return Challenge parameter value.
*/
static std::string GetChallengeParameter(
std::string const& challenge,
std::string const& challengeScheme,
std::string const& challengeParameter);
};
} // namespace _internal
}}} // namespace Azure::Core::Credentials

View File

@ -0,0 +1,274 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/internal/credentials/authorization_challenge_parser.hpp"
#include "azure/core/internal/strings.hpp"
#include <set>
using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper;
using Azure::Core::Credentials::_internal::AuthorizationChallengeParser;
using Azure::Core::_internal::StringExtensions;
using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;
// The parser is implemented to closely mimic the logic in .NET SDK:
// https://github.com/Azure/azure-sdk-for-net/blob/b7efe81fe69e020df74853b0f501ca6ccd5c94a1/sdk/core/Azure.Core/src/Shared/AuthorizationChallengeParser.cs
namespace {
class StringSpan final {
std::string const* m_stringPtr;
int m_startPos;
int m_endPosExclusive;
public:
StringSpan(std::string const* stringPtr = nullptr);
StringSpan(StringSpan const&) = default;
StringSpan& operator=(StringSpan const&) = default;
int Length() const;
std::string ToString() const;
StringSpan Slice(int start) const;
StringSpan Slice(int start, int length) const;
StringSpan TrimStart(std::set<char> const& chars) const;
StringSpan Trim(std::set<char> const& chars) const;
int IndexOfAny(std::set<char> const& chars) const;
bool CaseInsensitiveEquals(StringSpan const& other) const;
};
bool TryGetNextChallenge(StringSpan& headerValue, StringSpan& challengeKey);
bool TryGetNextParameter(StringSpan& headerValue, StringSpan& paramKey, StringSpan& paramValue);
std::string const EmptyString;
} // namespace
std::string const& AuthorizationChallengeHelper::GetChallenge(Http::RawResponse const& response)
{
// See RFC7235 (https://www.rfc-editor.org/rfc/rfc7235#section-4.1)
if (response.GetStatusCode() == HttpStatusCode::Unauthorized)
{
auto const& headers = response.GetHeaders();
auto const wwwAuthHeader = headers.find("WWW-Authenticate");
if (wwwAuthHeader != headers.end())
{
return wwwAuthHeader->second;
}
}
return EmptyString;
}
std::string AuthorizationChallengeParser::GetChallengeParameter(
std::string const& challenge,
std::string const& challengeScheme,
std::string const& challengeParameter)
{
StringSpan bearer = &challengeScheme;
StringSpan claims = &challengeParameter;
StringSpan headerSpan = &challenge;
// Iterate through each challenge value.
StringSpan challengeKey;
while (TryGetNextChallenge(headerSpan, challengeKey))
{
// Enumerate each key=value parameter until we find the 'claims' key on the 'Bearer'
// challenge.
StringSpan key;
StringSpan value;
while (TryGetNextParameter(headerSpan, key, value))
{
if (challengeKey.CaseInsensitiveEquals(bearer) && key.CaseInsensitiveEquals(claims))
{
return value.ToString();
}
}
}
return {};
}
namespace {
std::set<char> const Space = {' '};
std::set<char> const SpaceOrComma = {' ', ','};
std::set<char> const Separator = {'='};
std::set<char> const Quote = {'\"'};
bool TryGetNextChallenge(StringSpan& headerValue, StringSpan& challengeKey)
{
challengeKey = StringSpan();
headerValue = headerValue.TrimStart(Space);
int endOfChallengeKey = headerValue.IndexOfAny(Space);
if (endOfChallengeKey < 0)
{
return false;
}
challengeKey = headerValue.Slice(0, endOfChallengeKey);
// Slice the challenge key from the headerValue
headerValue = headerValue.Slice(endOfChallengeKey + 1);
return true;
}
bool TryGetNextParameter(StringSpan& headerValue, StringSpan& paramKey, StringSpan& paramValue)
{
paramKey = StringSpan();
paramValue = StringSpan();
// Trim any separator prefixes.
headerValue = headerValue.TrimStart(SpaceOrComma);
int nextSpace = headerValue.IndexOfAny(Space);
int nextSeparator = headerValue.IndexOfAny(Separator);
if (nextSpace < nextSeparator && nextSpace != -1)
{
// we encountered another challenge value.
return false;
}
if (nextSeparator < 0)
{
return false;
}
// Get the paramKey.
paramKey = headerValue.Slice(0, nextSeparator).Trim(Space);
// Slice to remove the 'paramKey=' from the parameters.
headerValue = headerValue.Slice(nextSeparator + 1);
// The start of paramValue will usually be a quoted string. Find the first quote.
int quoteIndex = headerValue.IndexOfAny(Quote);
// Get the paramValue, which is delimited by the trailing quote.
headerValue = headerValue.Slice(quoteIndex + 1);
if (quoteIndex >= 0)
{
// The values are quote wrapped
paramValue = headerValue.Slice(0, headerValue.IndexOfAny(Quote));
}
else
{
// the values are not quote wrapped (storage is one example of this)
// either find the next space indicating the delimiter to the next value, or go to the end
// since this is the last value.
int trailingDelimiterIndex = headerValue.IndexOfAny(SpaceOrComma);
if (trailingDelimiterIndex >= 0)
{
paramValue = headerValue.Slice(0, trailingDelimiterIndex);
}
else
{
paramValue = headerValue;
}
}
// Slice to remove the '"paramValue"' from the parameters.
if (!headerValue.CaseInsensitiveEquals(paramValue))
headerValue = headerValue.Slice(paramValue.Length() + 1);
return true;
}
StringSpan::StringSpan(std::string const* stringPtr)
: m_stringPtr(stringPtr), m_startPos(0),
m_endPosExclusive(stringPtr ? static_cast<int>(stringPtr->size()) : 0)
{
}
int StringSpan::Length() const { return m_endPosExclusive - m_startPos; }
std::string StringSpan::ToString() const
{
return m_stringPtr
? m_stringPtr->substr(static_cast<size_t>(m_startPos), static_cast<size_t>(Length()))
: std::string{};
}
StringSpan StringSpan::Slice(int start) const
{
StringSpan result = *this;
result.m_startPos += start;
return result;
}
StringSpan StringSpan::Slice(int start, int length) const
{
auto result = Slice(start);
result.m_endPosExclusive = result.m_startPos + length;
return result;
}
StringSpan StringSpan::TrimStart(std::set<char> const& chars) const
{
StringSpan result = *this;
auto pos = result.m_startPos;
for (; pos < result.m_endPosExclusive; ++pos)
{
if (chars.find(result.m_stringPtr->operator[](static_cast<size_t>(pos))) == chars.end())
{
break;
}
}
result.m_startPos = pos;
return result;
}
StringSpan StringSpan::Trim(std::set<char> const& chars) const
{
StringSpan result = TrimStart(chars);
auto endPos = m_endPosExclusive;
for (; endPos > result.m_startPos; --endPos)
{
if (chars.find(result.m_stringPtr->operator[](static_cast<size_t>(endPos - 1))) == chars.end())
{
break;
}
}
result.m_endPosExclusive = endPos;
return result;
}
int StringSpan::IndexOfAny(std::set<char> const& chars) const
{
for (auto pos = m_startPos; pos < m_endPosExclusive; ++pos)
{
if (chars.find(m_stringPtr->operator[](static_cast<size_t>(pos))) != chars.end())
{
return pos - m_startPos;
}
}
return -1;
}
bool StringSpan::CaseInsensitiveEquals(StringSpan const& other) const
{
auto const length = Length();
if (length != other.Length())
{
return false;
}
for (auto offset = 0; offset < length; ++offset)
{
if (StringExtensions::ToLower(m_stringPtr->operator[](static_cast<size_t>(m_startPos + offset)))
!= StringExtensions::ToLower(
other.m_stringPtr->operator[](static_cast<size_t>(other.m_startPos + offset))))
{
return false;
}
}
return true;
}
} // namespace

View File

@ -1,16 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/credentials/credentials.hpp"
#include "azure/core/http/policies/policy.hpp"
#include "azure/core/credentials/credentials.hpp"
#include "azure/core/internal/credentials/authorization_challenge_parser.hpp"
#include <chrono>
using Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy;
using Azure::Core::Context;
using namespace Azure::Core::Http;
using namespace Azure::Core::Http::Policies;
using namespace Azure::Core::Http::Policies::_internal;
using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper;
using Azure::Core::Http::RawResponse;
using Azure::Core::Http::Policies::NextHttpPolicy;
std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
Request& request,
@ -23,17 +28,55 @@ std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::Send(
"Bearer token authentication is not permitted for non TLS protected (https) endpoints.");
}
auto result = AuthorizeAndSendRequest(request, nextPolicy, context);
{
std::lock_guard<std::mutex> lock(m_accessTokenMutex);
if (std::chrono::system_clock::now()
> (m_accessToken.ExpiresOn - m_tokenRequestContext.MinimumExpiration))
auto const& response = *result;
auto const& challenge = AuthorizationChallengeHelper::GetChallenge(response);
if (!challenge.empty() && AuthorizeRequestOnChallenge(challenge, request, context))
{
m_accessToken = m_credential->GetToken(m_tokenRequestContext, context);
result = nextPolicy.Send(request, context);
}
request.SetHeader("authorization", "Bearer " + m_accessToken.Token);
}
return result;
}
std::unique_ptr<RawResponse> BearerTokenAuthenticationPolicy::AuthorizeAndSendRequest(
Request& request,
NextHttpPolicy& nextPolicy,
Context const& context) const
{
AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context);
return nextPolicy.Send(request, context);
}
bool BearerTokenAuthenticationPolicy::AuthorizeRequestOnChallenge(
std::string const& challenge,
Request& request,
Context const& context) const
{
static_cast<void>(challenge);
static_cast<void>(request);
static_cast<void>(context);
return false;
}
void BearerTokenAuthenticationPolicy::AuthenticateAndAuthorizeRequest(
Request& request,
Credentials::TokenRequestContext const& tokenRequestContext,
Context const& context) const
{
std::lock_guard<std::mutex> lock(m_accessTokenMutex);
if (tokenRequestContext.TenantId != m_accessTokenContext.TenantId
|| tokenRequestContext.Scopes != m_accessTokenContext.Scopes
|| std::chrono::system_clock::now()
> (m_accessToken.ExpiresOn - tokenRequestContext.MinimumExpiration))
{
m_accessToken = m_credential->GetToken(tokenRequestContext, context);
m_accessTokenContext = tokenRequestContext;
}
request.SetHeader("authorization", "Bearer " + m_accessToken.Token);
}

View File

@ -42,6 +42,7 @@ endif()
add_executable (
azure-core-test
authorization_challenge_parser_test.cpp
azure_core_test.cpp
base64_test.cpp
bearer_token_authentication_policy_test.cpp

View File

@ -0,0 +1,207 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include <azure/core/internal/credentials/authorization_challenge_parser.hpp>
#include <gtest/gtest.h>
using Azure::Core::Credentials::_detail::AuthorizationChallengeHelper;
using Azure::Core::Credentials::_internal::AuthorizationChallengeParser;
using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;
namespace {
RawResponse CreateRawResponseWithWwwAuthHeader(
std::string const& value,
HttpStatusCode httpStatusCode = HttpStatusCode::Unauthorized)
{
RawResponse result(1, 1, httpStatusCode, "Test");
result.SetHeader("WWW-Authenticate", value);
return result;
}
std::string GetChallengeParameterFromResponse(
RawResponse const& response,
std::string const& challengeScheme,
std::string const& challengeParameter)
{
return AuthorizationChallengeParser::GetChallengeParameter(
AuthorizationChallengeHelper::GetChallenge(response), challengeScheme, challengeParameter);
}
} // namespace
TEST(AuthorizationChallengeParser, Simple)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Bearer key=value"), "Bearer", "key"),
"value");
}
TEST(AuthorizationChallengeParser, EmptyString)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(CreateRawResponseWithWwwAuthHeader(""), "Bearer", "key"),
"");
}
TEST(AuthorizationChallengeParser, Non401)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Bearer key=value", HttpStatusCode::Ok),
"Bearer",
"key"),
"");
}
TEST(AuthorizationChallengeParser, NoHeader)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
RawResponse(1, 1, HttpStatusCode::Unauthorized, "Test"), "Bearer", "key"),
"");
}
TEST(AuthorizationChallengeParser, KeyNotFound)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Bearer otherkey=value"), "Bearer", "key"),
"");
}
TEST(AuthorizationChallengeParser, SchemeNotFound)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Basic key=value"), "Bearer", "key"),
"");
}
TEST(AuthorizationChallengeParser, NotFoundForScheme)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Basic key=value, Bearer otherkey=value"),
"Bearer",
"key"),
"");
}
TEST(AuthorizationChallengeParser, MultiplSchemeMatch)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader(
"Basic key=value1, Bearer key=value2, Digest key=value3"),
"Bearer",
"key"),
"value2");
}
TEST(AuthorizationChallengeParser, Quoted)
{
EXPECT_EQ(
GetChallengeParameterFromResponse(
CreateRawResponseWithWwwAuthHeader("Bearer key=\"v a l u e\""), "Bearer", "key"),
"v a l u e");
}
TEST(AuthorizationChallengeParser, CaeInsufficientClaimsChallenge)
{
auto const response = CreateRawResponseWithWwwAuthHeader(
"Bearer realm=\"\", "
"authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", "
"client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", "
"claims=\"eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=\"");
EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "realm"), "");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"),
"https://login.microsoftonline.com/common/oauth2/authorize");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "client_id"),
"00000003-0000-0000-c000-000000000000");
EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "insufficient_claims");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "claims"),
"eyJhY2Nlc3NfdG9rZW4iOiB7ImZvbyI6ICJiYXIifX0=");
}
TEST(AuthorizationChallengeParser, CaeSessionsRevokedClaimsChallenge)
{
auto const response = CreateRawResponseWithWwwAuthHeader(
"Bearer authorization_uri=\"https://login.windows-ppe.net/\", error=\"invalid_token\", "
"error_description=\"User session has been revoked\", "
"claims="
"\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="
"\"");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"),
"https://login.windows-ppe.net/");
EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "invalid_token");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "error_description"),
"User session has been revoked");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "claims"),
"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=");
}
TEST(AuthorizationChallengeParser, KeyVaultChallenge)
{
auto const response = CreateRawResponseWithWwwAuthHeader(
"Bearer "
"authorization=\"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47\", "
"resource=\"https://vault.azure.net\"");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "authorization"),
"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "resource"), "https://vault.azure.net");
}
TEST(AuthorizationChallengeParser, ArmChallenge)
{
auto const response = CreateRawResponseWithWwwAuthHeader(
"Bearer authorization_uri=\"https://login.windows.net/\", error=\"invalid_token\", "
"error_description=\"The authentication failed because of missing 'Authorization' header.\"");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"),
"https://login.windows.net/");
EXPECT_EQ(GetChallengeParameterFromResponse(response, "Bearer", "error"), "invalid_token");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "error_description"),
"The authentication failed because of missing 'Authorization' header.");
}
TEST(AuthorizationChallengeParser, StorageChallenge)
{
auto const response = CreateRawResponseWithWwwAuthHeader(
"Bearer "
"authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/"
"oauth2/authorize resource_id=https://storage.azure.com");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "authorization_uri"),
"https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize");
EXPECT_EQ(
GetChallengeParameterFromResponse(response, "Bearer", "resource_id"),
"https://storage.azure.com");
}

View File

@ -8,34 +8,44 @@
#include <gtest/gtest.h>
using Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy;
using Azure::Core::Context;
using Azure::Core::Url;
using Azure::Core::Credentials::AccessToken;
using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenCredential;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Http::HttpMethod;
using Azure::Core::Http::HttpStatusCode;
using Azure::Core::Http::RawResponse;
using Azure::Core::Http::Request;
using Azure::Core::Http::_internal::HttpPipeline;
using Azure::Core::Http::Policies::HttpPolicy;
using Azure::Core::Http::Policies::NextHttpPolicy;
namespace {
class TestTokenCredential final : public Azure::Core::Credentials::TokenCredential {
class TestTokenCredential final : public TokenCredential {
private:
std::shared_ptr<Azure::Core::Credentials::AccessToken const> m_accessToken;
std::shared_ptr<AccessToken const> m_accessToken;
public:
explicit TestTokenCredential(
std::shared_ptr<Azure::Core::Credentials::AccessToken const> accessToken)
explicit TestTokenCredential(std::shared_ptr<AccessToken const> accessToken)
: TokenCredential("TestTokenCredential"), m_accessToken(accessToken)
{
}
Azure::Core::Credentials::AccessToken GetToken(
Azure::Core::Credentials::TokenRequestContext const&,
Azure::Core::Context const&) const override
AccessToken GetToken(TokenRequestContext const&, Context const&) const override
{
return *m_accessToken;
}
};
class TestTransportPolicy final : public Azure::Core::Http::Policies::HttpPolicy {
class TestTransportPolicy final : public HttpPolicy {
public:
std::unique_ptr<Azure::Core::Http::RawResponse> Send(
Azure::Core::Http::Request&,
Azure::Core::Http::Policies::NextHttpPolicy,
Azure::Core::Context const&) const override
std::unique_ptr<RawResponse> Send(Request&, NextHttpPolicy, Context const&) const override
{
return nullptr;
return std::make_unique<RawResponse>(1, 1, HttpStatusCode::Ok, "TestStatus");
}
std::unique_ptr<HttpPolicy> Clone() const override
@ -43,34 +53,31 @@ public:
return std::make_unique<TestTransportPolicy>(*this);
}
};
} // namespace
TEST(BearerTokenAuthenticationPolicy, InitialGet)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<Azure::Core::Credentials::AccessToken>();
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> policies;
std::vector<std::unique_ptr<HttpPolicy>> policies;
Azure::Core::Credentials::TokenRequestContext tokenRequestContext;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(
std::make_unique<Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
Azure::Core::Http::_internal::HttpPipeline pipeline(policies);
HttpPipeline pipeline(policies);
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
@ -84,37 +91,34 @@ TEST(BearerTokenAuthenticationPolicy, InitialGet)
TEST(BearerTokenAuthenticationPolicy, ReuseWhileValid)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<Azure::Core::Credentials::AccessToken>();
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> policies;
std::vector<std::unique_ptr<HttpPolicy>> policies;
Azure::Core::Credentials::TokenRequestContext tokenRequestContext;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(
std::make_unique<Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
Azure::Core::Http::_internal::HttpPipeline pipeline(policies);
HttpPipeline pipeline(policies);
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 5min};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
}
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
@ -128,37 +132,34 @@ TEST(BearerTokenAuthenticationPolicy, ReuseWhileValid)
TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<Azure::Core::Credentials::AccessToken>();
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> policies;
std::vector<std::unique_ptr<HttpPolicy>> policies;
Azure::Core::Credentials::TokenRequestContext tokenRequestContext;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(
std::make_unique<Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
Azure::Core::Http::_internal::HttpPipeline pipeline(policies);
HttpPipeline pipeline(policies);
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now() + 2min};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
}
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
@ -172,37 +173,34 @@ TEST(BearerTokenAuthenticationPolicy, RefreshNearExpiry)
TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<Azure::Core::Credentials::AccessToken>();
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> policies;
std::vector<std::unique_ptr<HttpPolicy>> policies;
Azure::Core::Credentials::TokenRequestContext tokenRequestContext;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(
std::make_unique<Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
Azure::Core::Http::_internal::HttpPipeline pipeline(policies);
HttpPipeline pipeline(policies);
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now()};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
}
{
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("https://www.azure.com"));
Request request(HttpMethod::Get, Url("https://www.azure.com"));
*accessToken = {"ACCESSTOKEN2", std::chrono::system_clock::now() + 1h};
pipeline.Send(request, Azure::Core::Context());
pipeline.Send(request, Context());
{
auto const headers = request.GetHeaders();
@ -213,30 +211,251 @@ TEST(BearerTokenAuthenticationPolicy, RefreshAfterExpiry)
}
}
TEST(BearerTokenAuthenticationPolicy, HttpEndpoint)
TEST(BearerTokenAuthenticationPolicy, NonHttps)
{
using namespace std::chrono_literals;
auto accessToken = std::make_shared<Azure::Core::Credentials::AccessToken>();
auto accessToken = std::make_shared<AccessToken>();
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> policies;
std::vector<std::unique_ptr<HttpPolicy>> policies;
Azure::Core::Credentials::TokenRequestContext tokenRequestContext;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(
std::make_unique<Azure::Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<BearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredential>(accessToken), tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
Azure::Core::Http::_internal::HttpPipeline pipeline(policies);
HttpPipeline pipeline(policies);
Azure::Core::Http::Request request(
Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("http://www.azure.com"));
Request request(HttpMethod::Get, Url("http://www.azure.com"));
*accessToken = {"ACCESSTOKEN1", std::chrono::system_clock::now()};
EXPECT_THROW(
static_cast<void>(pipeline.Send(request, Azure::Core::Context())),
Azure::Core::Credentials::AuthenticationException);
EXPECT_THROW(static_cast<void>(pipeline.Send(request, Context())), AuthenticationException);
}
namespace {
class TestBearerTokenAuthenticationPolicy final : public BearerTokenAuthenticationPolicy {
public:
TestBearerTokenAuthenticationPolicy(
std::shared_ptr<TokenCredential const> credential,
TokenRequestContext tokenRequestContext)
: BearerTokenAuthenticationPolicy(credential, tokenRequestContext)
{
}
std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<TestBearerTokenAuthenticationPolicy>(*this);
}
protected:
bool AuthorizeRequestOnChallenge(std::string const&, Request&, Context const&) const override
{
EXPECT_FALSE("AuthorizeRequestOnChallenge() should not get called if AuthorizeAndSendRequest() "
"was successful.");
return false;
}
};
class TestTokenCredentialForBearerTokenAuthenticationPolicy final : public TokenCredential {
public:
explicit TestTokenCredentialForBearerTokenAuthenticationPolicy()
: TokenCredential("TestTokenCredentialForBearerTokenAuthenticationPolicy")
{
}
AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const&)
const override
{
EXPECT_EQ(tokenRequestContext.Scopes.size(), 1);
EXPECT_EQ(tokenRequestContext.Scopes[0], "https://microsoft.com/.default");
EXPECT_TRUE(tokenRequestContext.TenantId.empty());
AccessToken result;
result.Token = "ACCESSTOKEN";
return result;
}
};
class TestChallengeBasedAuthenticationPolicy final : public BearerTokenAuthenticationPolicy {
private:
bool m_successfulAuthOnChallenge;
public:
TestChallengeBasedAuthenticationPolicy(
std::shared_ptr<TokenCredential const> credential,
TokenRequestContext tokenRequestContext,
bool successfulAuthOnChallenge)
: BearerTokenAuthenticationPolicy(credential, tokenRequestContext),
m_successfulAuthOnChallenge(successfulAuthOnChallenge)
{
}
std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<TestChallengeBasedAuthenticationPolicy>(*this);
}
protected:
std::unique_ptr<RawResponse> AuthorizeAndSendRequest(
Request& request,
NextHttpPolicy& nextHttpPolicy,
Context const& context) const override
{
EXPECT_EQ(request.GetUrl().GetAbsoluteUrl(), "https://www.azure.com");
TokenRequestContext trc;
trc.Scopes = {"https://visualstudio.com/.default"};
trc.TenantId = "TestTenantId1";
AuthenticateAndAuthorizeRequest(request, trc, context);
// Simulate as if we got challenge-based response and not a HTTP 200.
static_cast<void>(nextHttpPolicy.Send(request, context));
auto response = std::make_unique<RawResponse>(1, 1, HttpStatusCode::Unauthorized, "TestStatus");
response->SetHeader("WWW-Authenticate", "TestChallenge");
return response;
}
bool AuthorizeRequestOnChallenge(
std::string const& challenge,
Request& request,
Context const& context) const override
{
EXPECT_EQ(challenge, "TestChallenge");
TokenRequestContext trc;
trc.Scopes = {"https://xbox.com/.default"};
trc.TenantId = "TestTenantId2";
if (m_successfulAuthOnChallenge)
{
AuthenticateAndAuthorizeRequest(request, trc, context);
return true;
}
return false;
}
};
class TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy final : public TokenCredential {
private:
mutable int m_invokedTimes;
public:
explicit TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy()
: TokenCredential("TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy"),
m_invokedTimes(0)
{
}
AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const&)
const override
{
++m_invokedTimes;
EXPECT_GE(m_invokedTimes, 1);
EXPECT_LE(m_invokedTimes, 2);
if (m_invokedTimes == 1)
{
EXPECT_EQ(tokenRequestContext.Scopes.size(), 1);
EXPECT_EQ(tokenRequestContext.Scopes[0], "https://visualstudio.com/.default");
EXPECT_EQ(tokenRequestContext.TenantId, "TestTenantId1");
AccessToken result;
result.Token = "ACCESSTOKEN1";
return result;
}
EXPECT_EQ(tokenRequestContext.Scopes.size(), 1);
EXPECT_EQ(tokenRequestContext.Scopes[0], "https://xbox.com/.default");
EXPECT_EQ(tokenRequestContext.TenantId, "TestTenantId2");
AccessToken result;
result.Token = "ACCESSTOKEN2";
return result;
}
};
} // namespace
TEST(BearerTokenAuthenticationPolicy, ChallengeBasedSupport)
{
std::vector<std::unique_ptr<HttpPolicy>> policies;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(std::make_unique<TestBearerTokenAuthenticationPolicy>(
std::make_shared<TestTokenCredentialForBearerTokenAuthenticationPolicy>(),
tokenRequestContext));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
HttpPipeline pipeline(policies);
Request request(HttpMethod::Get, Url("https://www.azure.com"));
static_cast<void>(pipeline.Send(request, Context()));
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN");
}
TEST(BearerTokenAuthenticationPolicy, ChallengeBasedSuccess)
{
std::vector<std::unique_ptr<HttpPolicy>> policies;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(std::make_unique<TestChallengeBasedAuthenticationPolicy>(
std::make_shared<TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy>(),
tokenRequestContext,
true));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
HttpPipeline pipeline(policies);
Request request(HttpMethod::Get, Url("https://www.azure.com"));
auto const response = pipeline.Send(request, Context());
EXPECT_EQ(response->GetStatusCode(), HttpStatusCode::Ok);
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN2");
}
TEST(BearerTokenAuthenticationPolicy, ChallengeBasedFailure)
{
std::vector<std::unique_ptr<HttpPolicy>> policies;
TokenRequestContext tokenRequestContext;
tokenRequestContext.Scopes = {"https://microsoft.com/.default"};
policies.emplace_back(std::make_unique<TestChallengeBasedAuthenticationPolicy>(
std::make_shared<TestTokenCredentialForChallengeBasedTokenAuthenticationPolicy>(),
tokenRequestContext,
false));
policies.emplace_back(std::make_unique<TestTransportPolicy>());
HttpPipeline pipeline(policies);
Request request(HttpMethod::Get, Url("https://www.azure.com"));
auto const response = pipeline.Send(request, Context());
EXPECT_EQ(response->GetStatusCode(), HttpStatusCode::Unauthorized);
auto const headers = request.GetHeaders();
auto const authHeader = headers.find("authorization");
EXPECT_NE(authHeader, headers.end());
EXPECT_EQ(authHeader->second, "Bearer ACCESSTOKEN1");
}

View File

@ -4,6 +4,8 @@
### Features Added
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes
### Bugs Fixed

View File

@ -66,6 +66,7 @@ set(
src/private/identity_log.hpp
src/private/managed_identity_source.hpp
src/private/package_version.hpp
src/private/tenant_id_resolver.hpp
src/private/token_credential_impl.hpp
src/azure_cli_credential.cpp
src/chained_token_credential.cpp
@ -76,6 +77,7 @@ set(
src/environment_credential.cpp
src/managed_identity_credential.cpp
src/managed_identity_source.cpp
src/tenant_id_resolver.cpp
src/token_cache.cpp
src/token_credential_impl.cpp
)

View File

@ -17,6 +17,7 @@
#include <chrono>
#include <string>
#include <vector>
namespace Azure { namespace Identity {
/**
@ -36,6 +37,13 @@ namespace Azure { namespace Identity {
*/
DateTime::duration CliProcessTimeout
= std::chrono::seconds(13); // Value was taken from .NET SDK.
/**
* @brief For multi-tenant applications, specifies additional tenants for which the credential
* may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens
* for any tenant in which the application is installed.
*/
std::vector<std::string> AdditionallyAllowedTenants;
};
/**
@ -49,14 +57,16 @@ namespace Azure { namespace Identity {
: public Core::Credentials::TokenCredential {
protected:
_detail::TokenCache m_tokenCache;
std::vector<std::string> m_additionallyAllowedTenants;
std::string m_tenantId;
DateTime::duration m_cliProcessTimeout;
private:
explicit AzureCliCredential(
Core::Credentials::TokenCredentialOptions const& options,
std::string tenantId,
DateTime::duration cliProcessTimeout,
Core::Credentials::TokenCredentialOptions const& options);
std::vector<std::string> additionallyAllowedTenants);
void ThrowIfNotSafeCmdLineInput(std::string const& input, std::string const& description) const;

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
namespace Azure { namespace Identity {
namespace _detail {
@ -50,6 +51,13 @@ namespace Azure { namespace Identity {
* https://docs.microsoft.com/azure/active-directory/develop/authentication-national-cloud.
*/
std::string AuthorityHost = _detail::ClientCredentialCore::AadGlobalAuthority;
/**
* @brief For multi-tenant applications, specifies additional tenants for which the credential
* may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens
* for any tenant in which the application is installed.
*/
std::vector<std::string> AdditionallyAllowedTenants;
};
/**
@ -72,6 +80,7 @@ namespace Azure { namespace Identity {
std::string const& clientId,
std::string const& clientCertificatePath,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants,
Core::Credentials::TokenCredentialOptions const& options);
public:

View File

@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <vector>
namespace Azure { namespace Identity {
namespace _detail {
@ -38,6 +39,13 @@ namespace Azure { namespace Identity {
* https://docs.microsoft.com/azure/active-directory/develop/authentication-national-cloud.
*/
std::string AuthorityHost = _detail::ClientCredentialCore::AadGlobalAuthority;
/**
* @brief For multi-tenant applications, specifies additional tenants for which the credential
* may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens
* for any tenant in which the application is installed.
*/
std::vector<std::string> AdditionallyAllowedTenants;
};
/**
@ -57,6 +65,7 @@ namespace Azure { namespace Identity {
std::string const& clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants,
Core::Credentials::TokenCredentialOptions const& options);
public:

View File

@ -9,21 +9,33 @@
#include <azure/core/url.hpp>
#include <string>
#include <vector>
namespace Azure { namespace Identity { namespace _detail {
class ClientCredentialCore final {
std::vector<std::string> m_additionallyAllowedTenants;
Core::Url m_authorityHost;
std::string m_tenantId;
bool m_isAdfs;
public:
AZ_IDENTITY_DLLEXPORT static std::string const AadGlobalAuthority;
explicit ClientCredentialCore(std::string tenantId, std::string const& authorityHost);
explicit ClientCredentialCore(
std::string tenantId,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants);
Core::Url GetRequestUrl() const;
Core::Url GetRequestUrl(std::string const& tenantId) const;
std::string GetScopesString(decltype(Core::Credentials::TokenRequestContext::Scopes)
const& scopes) const;
std::string GetScopesString(
std::string const& tenantId,
decltype(Core::Credentials::TokenRequestContext::Scopes) const& scopes) const;
std::string const& GetTenantId() const { return m_tenantId; }
std::vector<std::string> const& GetAdditionallyAllowedTenants() const
{
return m_additionallyAllowedTenants;
}
};
}}} // namespace Azure::Identity::_detail

View File

@ -17,6 +17,7 @@
#include <memory>
#include <shared_mutex>
#include <string>
#include <tuple>
namespace Azure { namespace Identity { namespace _detail {
/**
@ -39,26 +40,27 @@ namespace Azure { namespace Identity { namespace _detail {
// A test hook that gets invoked before item write lock gets acquired.
virtual void OnBeforeItemWriteLock() const {};
struct CacheKey
{
std::string Scope;
std::string TenantId;
};
struct CacheKeyComparator
{
bool operator()(CacheKey const& lhs, CacheKey const& rhs) const
{
return std::tie(lhs.Scope, lhs.TenantId) < std::tie(rhs.Scope, rhs.TenantId);
}
};
struct CacheValue
{
Core::Credentials::AccessToken AccessToken;
std::shared_timed_mutex ElementMutex;
};
// The current cache Key, std::string Scopes, may later evolve to a struct that contains more
// fields. All that depends on the fields in the TokenRequestContext that are used as
// characteristics that go into the network request that gets the token.
// If tomorrow we add Multi-Tenant Authentication, and the TenantID stops being an immutable
// characteristic of a credential instance, but instead becomes variable depending on the fields
// of the TokenRequestContext that are taken into consideration as network request for the token
// is being sent, it should go into what will form the new CacheKey struct.
// i.e. we want all the variable inputs for obtaining a token to be a part of the key, because
// we want to have the same kind of result. There should be no "hidden variables".
// Otherwise, the cache will stop functioning properly, because the value you'd get from cache
// for a given key will fail to authenticate, but if the cache ends up calling the getNewToken
// callback, you'll authenticate successfully (however the other caller who need to get the
// token for slightly different context will not be as lucky).
mutable std::map<std::string, std::shared_ptr<CacheValue>> m_cache;
mutable std::map<CacheKey, std::shared_ptr<CacheValue>, CacheKeyComparator> m_cache;
mutable std::shared_timed_mutex m_cacheMutex;
private:
@ -73,7 +75,7 @@ namespace Azure { namespace Identity { namespace _detail {
// Gets item from cache, or creates it, puts into cache, and returns.
std::shared_ptr<CacheValue> GetOrCreateValue(
std::string const& key,
CacheKey const& key,
DateTime::duration minimumExpiration) const;
public:
@ -85,6 +87,7 @@ namespace Azure { namespace Identity { namespace _detail {
* provided, caches it, and returns its value.
*
* @param scopeString Authentication scopes (or resource) as string.
* @param tenantId TenantId for authentication.
* @param minimumExpiration Minimum token lifetime for the cached value to be returned.
* @param getNewToken Function to get the new token for the given \p scopeString, in case when
* cache does not have it, or if its remaining lifetime is less than \p minimumExpiration.
@ -94,6 +97,7 @@ namespace Azure { namespace Identity { namespace _detail {
*/
Core::Credentials::AccessToken GetToken(
std::string const& scopeString,
std::string const& tenantId,
DateTime::duration minimumExpiration,
std::function<Core::Credentials::AccessToken()> const& getNewToken) const;
};

View File

@ -12,37 +12,61 @@
#include <azure/core/credentials/token_credential_options.hpp>
#include <memory>
#include <string>
#include <vector>
namespace Azure { namespace Identity {
/**
* @brief Options for token authentication.
*
*/
struct EnvironmentCredentialOptions final : public Core::Credentials::TokenCredentialOptions
{
/**
* @brief For multi-tenant applications, specifies additional tenants for which the credential
* may acquire tokens. Add the wildcard value `"*"` to allow the credential to acquire tokens
* for any tenant in which the application is installed.
*/
std::vector<std::string> AdditionallyAllowedTenants;
};
/**
* @brief Environment Credential initializes an Azure credential, based on the system environment
* variables being set.
*
* @note May read from the following environment variables:
* - `AZURE_TENANT_ID`
* - `AZURE_CLIENT_ID`
* - `AZURE_CLIENT_SECRET`
* - `AZURE_CLIENT_CERTIFICATE_PATH`
* - `AZURE_CLIENT_CERTIFICATE_PASSWORD`
* - `AZURE_CLIENT_SEND_CERTIFICATE_CHAIN`
* - `AZURE_USERNAME`
* - `AZURE_PASSWORD`
* - `AZURE_AUTHORITY_HOST`
*/
class EnvironmentCredential final : public Core::Credentials::TokenCredential {
private:
std::unique_ptr<TokenCredential> m_credentialImpl;
explicit EnvironmentCredential(
Core::Credentials::TokenCredentialOptions const& options,
std::vector<std::string> const& additionallyAllowedTenants);
public:
/**
* @brief Constructs an Environment Credential.
*
* @param options Options for token retrieval.
*
* @note May read from the following environment variables:
* - AZURE_TENANT_ID
* - AZURE_CLIENT_ID
* - AZURE_CLIENT_SECRET
* - AZURE_CLIENT_CERTIFICATE_PATH
* - AZURE_CLIENT_CERTIFICATE_PASSWORD
* - AZURE_CLIENT_SEND_CERTIFICATE_CHAIN
* - AZURE_USERNAME
* - AZURE_PASSWORD
* - AZURE_AUTHORITY_HOST
*/
explicit EnvironmentCredential(
Azure::Core::Credentials::TokenCredentialOptions options
= Azure::Core::Credentials::TokenCredentialOptions());
Core::Credentials::TokenCredentialOptions const& options
= Core::Credentials::TokenCredentialOptions());
/**
* @brief Constructs an Environment Credential.
* @param options Options for token retrieval.
*/
explicit EnvironmentCredential(EnvironmentCredentialOptions const& options);
/**
* @brief Gets an authentication token.

View File

@ -4,6 +4,7 @@
#include "azure/identity/azure_cli_credential.hpp"
#include "private/identity_log.hpp"
#include "private/tenant_id_resolver.hpp"
#include "private/token_credential_impl.hpp"
#include <azure/core/internal/environment.hpp>
@ -47,6 +48,7 @@ using Azure::Core::Credentials::TokenCredentialOptions;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Identity::AzureCliCredentialOptions;
using Azure::Identity::_detail::IdentityLog;
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Identity::_detail::TokenCache;
using Azure::Identity::_detail::TokenCredentialImpl;
@ -77,11 +79,13 @@ void AzureCliCredential::ThrowIfNotSafeCmdLineInput(
}
}
AzureCliCredential::AzureCliCredential(
Core::Credentials::TokenCredentialOptions const& options,
std::string tenantId,
DateTime::duration cliProcessTimeout,
Core::Credentials::TokenCredentialOptions const& options)
: TokenCredential("AzureCliCredential"), m_tenantId(std::move(tenantId)),
m_cliProcessTimeout(std::move(cliProcessTimeout))
std::vector<std::string> additionallyAllowedTenants)
: TokenCredential("AzureCliCredential"),
m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)),
m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout))
{
static_cast<void>(options);
@ -95,15 +99,20 @@ AzureCliCredential::AzureCliCredential(
}
AzureCliCredential::AzureCliCredential(AzureCliCredentialOptions const& options)
: AzureCliCredential(options.TenantId, options.CliProcessTimeout, options)
: AzureCliCredential(
options,
options.TenantId,
options.CliProcessTimeout,
options.AdditionallyAllowedTenants)
{
}
AzureCliCredential::AzureCliCredential(TokenCredentialOptions const& options)
: AzureCliCredential(
options,
AzureCliCredentialOptions{}.TenantId,
AzureCliCredentialOptions{}.CliProcessTimeout,
options)
AzureCliCredentialOptions{}.AdditionallyAllowedTenants)
{
}
@ -133,15 +142,17 @@ AccessToken AzureCliCredential::GetToken(
Context const& context) const
{
auto const scopes = TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, false, false);
auto const tenantId
= TenantIdResolver::Resolve(m_tenantId, tokenRequestContext, m_additionallyAllowedTenants);
// TokenCache::GetToken() can only use the lambda argument when they are being executed. They are
// not supposed to keep a reference to lambda argument to call it later. Therefore, any capture
// made here will outlive the possible time frame when the lambda might get called.
return m_tokenCache.GetToken(scopes, tokenRequestContext.MinimumExpiration, [&]() {
// TokenCache::GetToken() can only use the lambda argument when they are being executed. They
// are not supposed to keep a reference to lambda argument to call it later. Therefore, any
// capture made here will outlive the possible time frame when the lambda might get called.
return m_tokenCache.GetToken(scopes, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
try
{
auto const azCliResult
= RunShellCommand(GetAzCommand(scopes, m_tenantId), m_cliProcessTimeout, context);
= RunShellCommand(GetAzCommand(scopes, tenantId), m_cliProcessTimeout, context);
try
{

View File

@ -3,6 +3,7 @@
#include "azure/identity/client_certificate_credential.hpp"
#include "private/tenant_id_resolver.hpp"
#include "private/token_credential_impl.hpp"
#include <azure/core/base64.hpp>
@ -33,6 +34,7 @@ using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenCredentialOptions;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Http::HttpMethod;
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Identity::_detail::TokenCredentialImpl;
namespace {
@ -79,9 +81,10 @@ ClientCertificateCredential::ClientCertificateCredential(
std::string const& clientId,
std::string const& clientCertificatePath,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants,
TokenCredentialOptions const& options)
: TokenCredential("ClientCertificateCredential"),
m_clientCredentialCore(tenantId, authorityHost),
m_clientCredentialCore(tenantId, authorityHost, additionallyAllowedTenants),
m_tokenCredentialImpl(std::make_unique<TokenCredentialImpl>(options)),
m_requestBody(
std::string(
@ -178,6 +181,7 @@ ClientCertificateCredential::ClientCertificateCredential(
clientId,
clientCertificatePath,
options.AuthorityHost,
options.AdditionallyAllowedTenants,
options)
{
}
@ -192,6 +196,7 @@ ClientCertificateCredential::ClientCertificateCredential(
clientId,
clientCertificatePath,
ClientCertificateCredentialOptions{}.AuthorityHost,
ClientCertificateCredentialOptions{}.AdditionallyAllowedTenants,
options)
{
}
@ -202,13 +207,19 @@ AccessToken ClientCertificateCredential::GetToken(
TokenRequestContext const& tokenRequestContext,
Context const& context) const
{
auto const scopesStr = m_clientCredentialCore.GetScopesString(tokenRequestContext.Scopes);
auto const tenantId = TenantIdResolver::Resolve(
m_clientCredentialCore.GetTenantId(),
tokenRequestContext,
m_clientCredentialCore.GetAdditionallyAllowedTenants());
auto const scopesStr
= m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes);
// TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
auto body = m_requestBody;
if (!scopesStr.empty())
@ -216,7 +227,7 @@ AccessToken ClientCertificateCredential::GetToken(
body += "&scope=" + scopesStr;
}
auto const requestUrl = m_clientCredentialCore.GetRequestUrl();
auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId);
std::string assertion = m_tokenHeaderEncoded;
{

View File

@ -3,6 +3,7 @@
#include "azure/identity/detail/client_credential_core.hpp"
#include "private/tenant_id_resolver.hpp"
#include "private/token_credential_impl.hpp"
#include <utility>
@ -11,29 +12,35 @@ using Azure::Identity::_detail::ClientCredentialCore;
using Azure::Core::Url;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Identity::_detail::TokenCredentialImpl;
decltype(ClientCredentialCore::AadGlobalAuthority) ClientCredentialCore::AadGlobalAuthority
= "https://login.microsoftonline.com/";
ClientCredentialCore::ClientCredentialCore(std::string tenantId, std::string const& authorityHost)
: m_authorityHost(Url(authorityHost)), m_tenantId(std::move(tenantId))
ClientCredentialCore::ClientCredentialCore(
std::string tenantId,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants)
: m_additionallyAllowedTenants(std::move(additionallyAllowedTenants)),
m_authorityHost(Url(authorityHost)), m_tenantId(std::move(tenantId))
{
// ADFS is the Active Directory Federation Service, a tenant ID that is used in Azure Stack.
m_isAdfs = m_tenantId == "adfs";
}
Url ClientCredentialCore::GetRequestUrl() const
Url ClientCredentialCore::GetRequestUrl(std::string const& tenantId) const
{
auto requestUrl = m_authorityHost;
requestUrl.AppendPath(m_tenantId);
requestUrl.AppendPath(m_isAdfs ? "oauth2/token" : "oauth2/v2.0/token");
requestUrl.AppendPath(tenantId);
requestUrl.AppendPath(TenantIdResolver::IsAdfs(tenantId) ? "oauth2/token" : "oauth2/v2.0/token");
return requestUrl;
}
std::string ClientCredentialCore::GetScopesString(decltype(TokenRequestContext::Scopes)
const& scopes) const
std::string ClientCredentialCore::GetScopesString(
std::string const& tenantId,
decltype(TokenRequestContext::Scopes) const& scopes) const
{
return scopes.empty() ? std::string() : TokenCredentialImpl::FormatScopes(scopes, m_isAdfs);
return scopes.empty()
? std::string()
: TokenCredentialImpl::FormatScopes(scopes, TenantIdResolver::IsAdfs(tenantId));
}

View File

@ -3,6 +3,7 @@
#include "azure/identity/client_secret_credential.hpp"
#include "private/tenant_id_resolver.hpp"
#include "private/token_credential_impl.hpp"
using Azure::Identity::ClientSecretCredential;
@ -13,6 +14,7 @@ using Azure::Core::Credentials::AccessToken;
using Azure::Core::Credentials::TokenCredentialOptions;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Core::Http::HttpMethod;
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Identity::_detail::TokenCredentialImpl;
ClientSecretCredential::ClientSecretCredential(
@ -20,8 +22,10 @@ ClientSecretCredential::ClientSecretCredential(
std::string const& clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::vector<std::string> additionallyAllowedTenants,
TokenCredentialOptions const& options)
: TokenCredential("ClientSecretCredential"), m_clientCredentialCore(tenantId, authorityHost),
: TokenCredential("ClientSecretCredential"),
m_clientCredentialCore(tenantId, authorityHost, additionallyAllowedTenants),
m_tokenCredentialImpl(std::make_unique<TokenCredentialImpl>(options)),
m_requestBody(
std::string("grant_type=client_credentials&client_id=") + Url::Encode(clientId)
@ -34,7 +38,13 @@ ClientSecretCredential::ClientSecretCredential(
std::string const& clientId,
std::string const& clientSecret,
ClientSecretCredentialOptions const& options)
: ClientSecretCredential(tenantId, clientId, clientSecret, options.AuthorityHost, options)
: ClientSecretCredential(
tenantId,
clientId,
clientSecret,
options.AuthorityHost,
options.AdditionallyAllowedTenants,
options)
{
}
@ -48,6 +58,7 @@ ClientSecretCredential::ClientSecretCredential(
clientId,
clientSecret,
ClientSecretCredentialOptions{}.AuthorityHost,
ClientSecretCredentialOptions{}.AdditionallyAllowedTenants,
options)
{
}
@ -58,13 +69,20 @@ AccessToken ClientSecretCredential::GetToken(
TokenRequestContext const& tokenRequestContext,
Context const& context) const
{
auto const scopesStr = m_clientCredentialCore.GetScopesString(tokenRequestContext.Scopes);
auto const tenantId = TenantIdResolver::Resolve(
m_clientCredentialCore.GetTenantId(),
tokenRequestContext,
m_clientCredentialCore.GetAdditionallyAllowedTenants());
auto const scopesStr
= m_clientCredentialCore.GetScopesString(tenantId, tokenRequestContext.Scopes);
// TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, tenantId, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCredentialImpl->GetToken(context, [&]() {
auto body = m_requestBody;
@ -73,7 +91,7 @@ AccessToken ClientSecretCredential::GetToken(
body += "&scope=" + scopesStr;
}
auto const requestUrl = m_clientCredentialCore.GetRequestUrl();
auto const requestUrl = m_clientCredentialCore.GetRequestUrl(tenantId);
auto request
= std::make_unique<TokenCredentialImpl::TokenRequest>(HttpMethod::Post, requestUrl, body);

View File

@ -14,6 +14,7 @@
#include <vector>
using Azure::Identity::EnvironmentCredential;
using Azure::Identity::EnvironmentCredentialOptions;
using Azure::Core::Context;
using Azure::Core::_internal::Environment;
@ -36,7 +37,9 @@ void PrintCredentialCreationLogMessage(
char const* credThatGetsCreated);
} // namespace
EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options)
EnvironmentCredential::EnvironmentCredential(
TokenCredentialOptions const& options,
std::vector<std::string> const& additionallyAllowedTenants)
: TokenCredential("EnvironmentCredential")
{
auto tenantId = Environment::GetVariable(AzureTenantIdEnvVarName);
@ -49,77 +52,48 @@ EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options)
if (!tenantId.empty() && !clientId.empty())
{
std::vector<std::pair<char const*, char const*>> envVarsToParams
= {{AzureTenantIdEnvVarName, "tenantId"}, {AzureClientIdEnvVarName, "clientId"}};
if (!clientSecret.empty())
{
envVarsToParams.push_back({AzureClientSecretEnvVarName, "clientSecret"});
ClientSecretCredentialOptions clientSecretCredentialOptions;
static_cast<TokenCredentialOptions&>(clientSecretCredentialOptions) = options;
clientSecretCredentialOptions.AdditionallyAllowedTenants = additionallyAllowedTenants;
if (!authority.empty())
{
PrintCredentialCreationLogMessage(
GetCredentialName(),
{
{AzureTenantIdEnvVarName, "tenantId"},
{AzureClientIdEnvVarName, "clientId"},
{AzureClientSecretEnvVarName, "clientSecret"},
{AzureAuthorityHostEnvVarName, "authorityHost"},
},
"ClientSecretCredential");
ClientSecretCredentialOptions clientSecretCredentialOptions;
static_cast<TokenCredentialOptions&>(clientSecretCredentialOptions) = options;
envVarsToParams.push_back({AzureAuthorityHostEnvVarName, "authorityHost"});
clientSecretCredentialOptions.AuthorityHost = authority;
m_credentialImpl.reset(new ClientSecretCredential(
tenantId, clientId, clientSecret, clientSecretCredentialOptions));
}
else
{
PrintCredentialCreationLogMessage(
GetCredentialName(),
{
{AzureTenantIdEnvVarName, "tenantId"},
{AzureClientIdEnvVarName, "clientId"},
{AzureClientSecretEnvVarName, "clientSecret"},
},
"ClientSecretCredential");
m_credentialImpl.reset(
new ClientSecretCredential(tenantId, clientId, clientSecret, options));
}
PrintCredentialCreationLogMessage(
GetCredentialName(), envVarsToParams, "ClientSecretCredential");
m_credentialImpl.reset(new ClientSecretCredential(
tenantId, clientId, clientSecret, clientSecretCredentialOptions));
}
else if (!clientCertificatePath.empty())
{
envVarsToParams.push_back({AzureClientCertificatePathEnvVarName, "clientCertificatePath"});
ClientCertificateCredentialOptions clientCertificateCredentialOptions;
static_cast<TokenCredentialOptions&>(clientCertificateCredentialOptions) = options;
clientCertificateCredentialOptions.AdditionallyAllowedTenants = additionallyAllowedTenants;
if (!authority.empty())
{
PrintCredentialCreationLogMessage(
GetCredentialName(),
{
{AzureTenantIdEnvVarName, "tenantId"},
{AzureClientIdEnvVarName, "clientId"},
{AzureClientCertificatePathEnvVarName, "clientCertificatePath"},
{AzureAuthorityHostEnvVarName, "authorityHost"},
},
"ClientCertificateCredential");
ClientCertificateCredentialOptions clientCertificateCredentialOptions;
static_cast<TokenCredentialOptions&>(clientCertificateCredentialOptions) = options;
envVarsToParams.push_back({AzureAuthorityHostEnvVarName, "authorityHost"});
clientCertificateCredentialOptions.AuthorityHost = authority;
m_credentialImpl.reset(new ClientCertificateCredential(
tenantId, clientId, clientCertificatePath, clientCertificateCredentialOptions));
}
else
{
PrintCredentialCreationLogMessage(
GetCredentialName(),
{
{AzureTenantIdEnvVarName, "tenantId"},
{AzureClientIdEnvVarName, "clientId"},
{AzureClientCertificatePathEnvVarName, "clientCertificatePath"},
},
"ClientCertificateCredential");
m_credentialImpl.reset(
new ClientCertificateCredential(tenantId, clientId, clientCertificatePath, options));
}
PrintCredentialCreationLogMessage(
GetCredentialName(), envVarsToParams, "ClientCertificateCredential");
m_credentialImpl.reset(new ClientCertificateCredential(
tenantId, clientId, clientCertificatePath, clientCertificateCredentialOptions));
}
}
@ -156,6 +130,16 @@ EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions options)
}
}
EnvironmentCredential::EnvironmentCredential(TokenCredentialOptions const& options)
: EnvironmentCredential(options, {})
{
}
EnvironmentCredential::EnvironmentCredential(EnvironmentCredentialOptions const& options)
: EnvironmentCredential(options, options.AdditionallyAllowedTenants)
{
}
AccessToken EnvironmentCredential::GetToken(
TokenRequestContext const& tokenRequestContext,
Context const& context) const

View File

@ -136,7 +136,7 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken(
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);
@ -218,7 +218,7 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken(
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
using Azure::Core::Url;
using Azure::Core::Http::HttpMethod;
@ -317,7 +317,7 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken(
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(
context,
createRequest,
@ -417,7 +417,7 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken(
// when they are being executed. They are not supposed to keep a reference to lambda argument to
// call it later. Therefore, any capture made here will outlive the possible time frame when the
// lambda might get called.
return m_tokenCache.GetToken(scopesStr, tokenRequestContext.MinimumExpiration, [&]() {
return m_tokenCache.GetToken(scopesStr, {}, tokenRequestContext.MinimumExpiration, [&]() {
return TokenCredentialImpl::GetToken(context, [&]() {
auto request = std::make_unique<TokenRequest>(m_request);

View File

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#pragma once
#include <azure/core/credentials/credentials.hpp>
#include <string>
#include <vector>
namespace Azure { namespace Identity { namespace _detail {
/**
* @brief Implements an access token cache.
*
*/
class TenantIdResolver final {
TenantIdResolver() = delete;
~TenantIdResolver() = delete;
public:
static std::string Resolve(
std::string const& explicitTenantId,
Core::Credentials::TokenRequestContext const& tokenRequestContext,
std::vector<std::string> const& additionallyAllowedTenants);
// ADFS is the Active Directory Federation Service, a tenant ID that is used in Azure Stack.
static bool IsAdfs(std::string const& tenantId);
};
}}} // namespace Azure::Identity::_detail

View File

@ -1,62 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Token cache.
*/
#pragma once
#include <azure/core/credentials/credentials.hpp>
#include <azure/core/datetime.hpp>
#include <functional>
#include <string>
namespace Azure { namespace Identity { namespace _detail {
/**
* @brief Implements an access token cache.
*
*/
class TokenCache final {
TokenCache() = delete;
~TokenCache() = delete;
public:
/**
* @brief Attempts to get token from cache, and if not found, gets the token using the function
* provided, caches it, and returns its value.
*
* @param tenantId Azure Tenant ID.
* @param clientId Azure Client ID.
* @param authorityHost Authentication authority URL.
* @param scopes Authentication scopes.
*
* @return Authentication token.
*
* @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred.
*/
static Core::Credentials::AccessToken GetToken(
std::string const& tenantId,
std::string const& clientId,
std::string const& authorityHost,
std::string const& scopes,
DateTime::duration minimumExpiration,
std::function<Core::Credentials::AccessToken()> const& getNewToken);
/**
* @brief Provides access to internal aspects of the cache as a test hook.
*
*/
class Internals;
#if defined(TESTING_BUILD)
/**
* @brief Clears token cache. Intended to only be used in tests.
*
*/
static void Clear();
#endif
};
}}} // namespace Azure::Identity::_detail

View File

@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "private/tenant_id_resolver.hpp"
#include <azure/core/internal/environment.hpp>
#include <azure/core/internal/strings.hpp>
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Core::_internal::Environment;
using Azure::Core::_internal::StringExtensions;
using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenRequestContext;
namespace {
bool IsMultitenantAuthDisabled()
{
auto const envVar = Environment::GetVariable("AZURE_IDENTITY_DISABLE_MULTITENANTAUTH");
return envVar == "1" || StringExtensions::LocaleInvariantCaseInsensitiveEqual(envVar, "true");
}
} // namespace
std::string TenantIdResolver::Resolve(
std::string const& explicitTenantId,
TokenRequestContext const& tokenRequestContext,
std::vector<std::string> const& additionallyAllowedTenants)
{
auto const& requestedTenantId = tokenRequestContext.TenantId;
if (requestedTenantId.empty()
|| StringExtensions::LocaleInvariantCaseInsensitiveEqual(requestedTenantId, explicitTenantId)
|| IsAdfs(explicitTenantId) || IsMultitenantAuthDisabled())
{
return explicitTenantId;
}
for (auto const& allowedTenantId : additionallyAllowedTenants)
{
if (allowedTenantId == "*"
|| StringExtensions::LocaleInvariantCaseInsensitiveEqual(
allowedTenantId, requestedTenantId))
{
return requestedTenantId;
}
}
throw AuthenticationException(
"The current credential is not configured to acquire tokens for tenant '" + requestedTenantId
+ "'. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on "
"the credential options, or add \"*\" to AdditionallyAllowedTenants to allow acquiring "
"tokens for any tenant.");
}
bool TenantIdResolver::IsAdfs(std::string const& tenantId)
{
return StringExtensions::LocaleInvariantCaseInsensitiveEqual(tenantId, "adfs");
}

View File

@ -26,7 +26,7 @@ template <typename T> bool ShouldCleanUpCacheFromExpiredItems(T cacheSize);
}
std::shared_ptr<TokenCache::CacheValue> TokenCache::GetOrCreateValue(
std::string const& key,
CacheKey const& key,
DateTime::duration minimumExpiration) const
{
{
@ -86,10 +86,11 @@ std::shared_ptr<TokenCache::CacheValue> TokenCache::GetOrCreateValue(
AccessToken TokenCache::GetToken(
std::string const& scopeString,
std::string const& tenantId,
DateTime::duration minimumExpiration,
std::function<AccessToken()> const& getNewToken) const
{
auto const item = GetOrCreateValue(scopeString, minimumExpiration);
auto const item = GetOrCreateValue({scopeString, tenantId}, minimumExpiration);
{
std::shared_lock<std::shared_timed_mutex> itemReadLock(item->ElementMutex);

View File

@ -27,6 +27,7 @@ add_executable (
macro_guard_test.cpp
managed_identity_credential_test.cpp
simplified_header_test.cpp
tenant_id_resolver_test.cpp
token_cache_test.cpp
token_credential_impl_test.cpp
token_credential_test.cpp

View File

@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "private/tenant_id_resolver.hpp"
#include "credential_test_helper.hpp"
#include <gtest/gtest.h>
using Azure::Identity::_detail::TenantIdResolver;
using Azure::Core::Credentials::AuthenticationException;
using Azure::Core::Credentials::TokenRequestContext;
using Azure::Identity::Test::_detail::CredentialTestHelper;
TEST(TenantIdResolver, RequestedTenantIdEmpty)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", ""}, // Default, i.e. NOT disabled
});
auto const tenantId = TenantIdResolver::Resolve("aA", {}, {});
EXPECT_EQ(tenantId, "aA");
}
TEST(TenantIdResolver, RequestedTenantIdEqualsExplicitTenantId)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "0"}, // Default, i.e. NOT disabled
});
TokenRequestContext trc;
trc.TenantId = "Aa";
auto const tenantId = TenantIdResolver::Resolve("aA", trc, {});
EXPECT_EQ(tenantId, "aA");
}
TEST(TenantIdResolver, Adfs)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "false"}, // Default, i.e. NOT disabled
});
TokenRequestContext trc;
trc.TenantId = "bB";
auto const tenantId = TenantIdResolver::Resolve("aDfS", trc, {});
EXPECT_EQ(tenantId, "aDfS");
}
TEST(TenantIdResolver, Disabled1)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "1"}, // Should be DISABLED
});
TokenRequestContext trc;
trc.TenantId = "bB";
auto const tenantId = TenantIdResolver::Resolve("aA", trc, {});
EXPECT_EQ(tenantId, "aA");
}
TEST(TenantIdResolver, DisabledTrue)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "tRuE"}, // Should be DISABLED
});
TokenRequestContext trc;
trc.TenantId = "bB";
auto const tenantId = TenantIdResolver::Resolve("aA", trc, {});
EXPECT_EQ(tenantId, "aA");
}
TEST(TenantIdResolver, Wildcard)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "2"}, // Not a value that should be recognized
});
TokenRequestContext trc;
trc.TenantId = "bB";
auto const tenantId = TenantIdResolver::Resolve("aA", trc, {"cC", "*", "dD"});
EXPECT_EQ(tenantId, "bB");
}
TEST(TenantIdResolver, Match)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "T"}, // Not a value that should be recognized
});
TokenRequestContext trc;
trc.TenantId = "bB";
auto const tenantId = TenantIdResolver::Resolve("bA", trc, {"cC", "Bb", "dD"});
EXPECT_EQ(tenantId, "bB");
}
TEST(TenantIdResolver, NoMatch)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "yes"}, // Not a value that should be recognized
});
TokenRequestContext trc;
trc.TenantId = "bB";
EXPECT_THROW(
static_cast<void>(TenantIdResolver::Resolve("aA", trc, {"cC", "dD"})),
AuthenticationException);
}
TEST(TenantIdResolver, NoMatchEmpty)
{
CredentialTestHelper::EnvironmentOverride const env(std::map<std::string, std::string>{
{"AZURE_IDENTITY_DISABLE_MULTITENANTAUTH", "on"}, // Not a value that should be recognized
});
TokenRequestContext trc;
trc.TenantId = "bB";
EXPECT_THROW(
static_cast<void>(TenantIdResolver::Resolve("aA", trc, {})), AuthenticationException);
}

View File

@ -53,7 +53,7 @@ TEST(TokenCache, GetReuseRefresh)
auto const Yesterday = Tomorrow - 48h;
{
auto const token1 = tokenCache.GetToken("A", 2min, [=]() {
auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -65,7 +65,7 @@ TEST(TokenCache, GetReuseRefresh)
EXPECT_EQ(token1.ExpiresOn, Tomorrow);
EXPECT_EQ(token1.Token, "T1");
auto const token2 = tokenCache.GetToken("A", 2min, [=]() {
auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "T2";
@ -80,9 +80,9 @@ TEST(TokenCache, GetReuseRefresh)
}
{
tokenCache.m_cache["A"]->AccessToken.ExpiresOn = Yesterday;
tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday;
auto const token = tokenCache.GetToken("A", 2min, [=]() {
auto const token = tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T3";
result.ExpiresOn = Tomorrow + 1min;
@ -106,7 +106,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey)
tokenCache.m_onBeforeCacheWriteLock = [&]() {
tokenCache.m_onBeforeCacheWriteLock = nullptr;
static_cast<void>(tokenCache.GetToken("A", 2min, [=]() {
static_cast<void>(tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -114,7 +114,7 @@ TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey)
}));
};
auto const token = tokenCache.GetToken("A", 2min, [=]() {
auto const token = tokenCache.GetToken("A", {}, 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before "
"acquiring cache write lock");
AccessToken result;
@ -141,12 +141,12 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken)
tokenCache.m_onBeforeItemWriteLock = [&]() {
tokenCache.m_onBeforeItemWriteLock = nullptr;
auto const item = tokenCache.m_cache["A"];
auto const item = tokenCache.m_cache[{"A", {}}];
item->AccessToken.Token = "T1";
item->AccessToken.ExpiresOn = Tomorrow;
};
auto const token = tokenCache.GetToken("A", 2min, [=]() {
auto const token = tokenCache.GetToken("A", {}, 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before "
"acquiring item write lock");
AccessToken result;
@ -167,12 +167,12 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken)
tokenCache.m_onBeforeItemWriteLock = [&]() {
tokenCache.m_onBeforeItemWriteLock = nullptr;
auto const item = tokenCache.m_cache["A"];
auto const item = tokenCache.m_cache[{"A", {}}];
item->AccessToken.Token = "T3";
item->AccessToken.ExpiresOn = Yesterday;
};
auto const token = tokenCache.GetToken("A", 2min, [=]() {
auto const token = tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T4";
result.ExpiresOn = Tomorrow + 3min;
@ -199,7 +199,7 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 1; i <= 35; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(tokenCache.GetToken(n, 2min, [=]() {
static_cast<void>(tokenCache.GetToken(n, {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -214,14 +214,14 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday;
tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday;
}
// Add tokens up to 55 total. When 56th gets added, clean up should get triggered.
for (auto i = 36; i <= 55; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(tokenCache.GetToken(n, 2min, [=]() {
static_cast<void>(tokenCache.GetToken(n, {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -235,11 +235,11 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end());
EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end());
}
// One more addition to the cache and cleanup for the expired ones will get triggered.
static_cast<void>(tokenCache.GetToken("56", 2min, [=]() {
static_cast<void>(tokenCache.GetToken("56", {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -253,14 +253,14 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 1; i <= 3; ++i)
{
auto const n = std::to_string(i);
EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end());
EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end());
}
// Let's expire items from 21 all the way up to 56.
for (auto i = 21; i <= 56; ++i)
{
auto const n = std::to_string(i);
tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday;
tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday;
}
// Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to
@ -268,7 +268,7 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 2; i <= 3; ++i)
{
auto const n = std::to_string(i);
static_cast<void>(tokenCache.GetToken(n, 2min, [=]() {
static_cast<void>(tokenCache.GetToken(n, {}, 2min, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow;
@ -284,19 +284,19 @@ TEST(TokenCache, ExpiredCleanup)
// Out of 4 locked, two are expired, so they should get cleared under normal circumstances, but
// this time they will remain in the cache.
std::shared_lock<std::shared_timed_mutex> readLockForUnexpired(
tokenCache.m_cache["2"]->ElementMutex);
tokenCache.m_cache[{"2", {}}]->ElementMutex);
std::shared_lock<std::shared_timed_mutex> readLockForExpired(
tokenCache.m_cache["54"]->ElementMutex);
tokenCache.m_cache[{"54", {}}]->ElementMutex);
std::unique_lock<std::shared_timed_mutex> writeLockForUnexpired(
tokenCache.m_cache["3"]->ElementMutex);
tokenCache.m_cache[{"3", {}}]->ElementMutex);
std::unique_lock<std::shared_timed_mutex> writeLockForExpired(
tokenCache.m_cache["55"]->ElementMutex);
tokenCache.m_cache[{"55", {}}]->ElementMutex);
// Count is at 55. Inserting the 56th element, and it will trigger cleanup.
static_cast<void>(tokenCache.GetToken("1", 2min, [=]() {
static_cast<void>(tokenCache.GetToken("1", {}, 2min, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow;
@ -309,17 +309,17 @@ TEST(TokenCache, ExpiredCleanup)
for (auto i = 1; i <= 20; ++i)
{
auto const n = std::to_string(i);
EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end());
EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end());
}
EXPECT_NE(tokenCache.m_cache.find("54"), tokenCache.m_cache.end());
EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end());
EXPECT_NE(tokenCache.m_cache.find("55"), tokenCache.m_cache.end());
EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end());
for (auto i = 21; i <= 53; ++i)
{
auto const n = std::to_string(i);
EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end());
EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end());
}
}
@ -331,7 +331,7 @@ TEST(TokenCache, MinimumExpiration)
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const token1 = tokenCache.GetToken("A", 2min, [=]() {
auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -343,7 +343,7 @@ TEST(TokenCache, MinimumExpiration)
EXPECT_EQ(token1.ExpiresOn, Tomorrow);
EXPECT_EQ(token1.Token, "T1");
auto const token2 = tokenCache.GetToken("A", 24h, [=]() {
auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() {
AccessToken result;
result.Token = "T2";
result.ExpiresOn = Tomorrow + 1h;
@ -364,7 +364,7 @@ TEST(TokenCache, MultithreadedAccess)
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
auto const token1 = tokenCache.GetToken("A", 2min, [=]() {
auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() {
AccessToken result;
result.Token = "T1";
result.ExpiresOn = Tomorrow;
@ -377,14 +377,15 @@ TEST(TokenCache, MultithreadedAccess)
EXPECT_EQ(token1.Token, "T1");
{
std::shared_lock<std::shared_timed_mutex> itemReadLock(tokenCache.m_cache["A"]->ElementMutex);
std::shared_lock<std::shared_timed_mutex> itemReadLock(
tokenCache.m_cache[{"A", {}}]->ElementMutex);
{
std::shared_lock<std::shared_timed_mutex> cacheReadLock(tokenCache.m_cacheMutex);
// Parallel threads read both the container and the item we're accessing, and we can access it
// in parallel as well.
auto const token2 = tokenCache.GetToken("A", 2min, [=]() {
// Parallel threads read both the container and the item we're accessing, and we can
// access it in parallel as well.
auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "T2";
@ -400,7 +401,7 @@ TEST(TokenCache, MultithreadedAccess)
// The cache is unlocked, but one item is being read in a parallel thread, which does not
// prevent new items (with different key) from being appended to cache.
auto const token3 = tokenCache.GetToken("B", 2min, [=]() {
auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() {
AccessToken result;
result.Token = "T3";
result.ExpiresOn = Tomorrow + 2h;
@ -414,11 +415,12 @@ TEST(TokenCache, MultithreadedAccess)
}
{
std::unique_lock<std::shared_timed_mutex> itemWriteLock(tokenCache.m_cache["A"]->ElementMutex);
std::unique_lock<std::shared_timed_mutex> itemWriteLock(
tokenCache.m_cache[{"A", {}}]->ElementMutex);
// The cache is unlocked, but one item is being written in a parallel thread, which does not
// prevent new items (with different key) from being appended to cache.
auto const token3 = tokenCache.GetToken("C", 2min, [=]() {
auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() {
AccessToken result;
result.Token = "T4";
result.ExpiresOn = Tomorrow + 3h;
@ -514,7 +516,8 @@ TEST(TokenCache, PerCredInstance)
auto const tokenB = credB.GetToken(getCached, {});
EXPECT_EQ(
tokenB.Token,
"SecretB2"); // if token cache was shared between instances, the value would be "SecretA1"
"SecretB2"); // if token cache was shared between instances, the value would be
// "SecretA1"
}
{
@ -541,3 +544,128 @@ TEST(TokenCache, PerCredInstance)
EXPECT_EQ(tokenA6.Token, "SecretA4");
}
}
TEST(TokenCache, TenantId)
{
TestableTokenCache tokenCache;
EXPECT_EQ(tokenCache.m_cache.size(), 0UL);
DateTime const Tomorrow = std::chrono::system_clock::now() + 24h;
{
auto const token = tokenCache.GetToken("A", "X", 2min, [=]() {
AccessToken result;
result.Token = "AX";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 1UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "AX");
}
{
auto const token = tokenCache.GetToken("B", "X", 2min, [=]() {
AccessToken result;
result.Token = "BX";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 2UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "BX");
}
{
auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() {
AccessToken result;
result.Token = "AY";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 3UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "AY");
}
{
auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() {
AccessToken result;
result.Token = "BY";
result.ExpiresOn = Tomorrow;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 4UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "BY");
}
{
auto const token = tokenCache.GetToken("A", "X", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "XA";
result.ExpiresOn = Tomorrow + 24h;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 4UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "AX");
}
{
auto const token = tokenCache.GetToken("B", "X", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "XB";
result.ExpiresOn = Tomorrow + 24h;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 4UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "BX");
}
{
auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "YA";
result.ExpiresOn = Tomorrow + 24h;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 4UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "AY");
}
{
auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() {
EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good");
AccessToken result;
result.Token = "YB";
result.ExpiresOn = Tomorrow + 24h;
return result;
});
EXPECT_EQ(tokenCache.m_cache.size(), 4UL);
EXPECT_EQ(token.ExpiresOn, Tomorrow);
EXPECT_EQ(token.Token, "BY");
}
}

View File

@ -14,7 +14,7 @@
{
"name": "azure-core-cpp",
"default-features": false,
"version>=": "1.8.0"
"version>=": "1.9.0-beta.1"
},
"openssl",
{

View File

@ -4,6 +4,8 @@
### Features Added
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes
### Bugs Fixed

View File

@ -10,6 +10,7 @@
#include <azure/core/internal/json/json_optional.hpp>
#include <azure/core/internal/json/json_serializable.hpp>
#include <azure/keyvault/administration/settings_client.hpp>
#include <azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp>
#include <azure/keyvault/shared/keyvault_shared.hpp>
#include <memory>
@ -51,7 +52,8 @@ SettingsClient::SettingsClient(
tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)};
perRetrypolicies.emplace_back(
std::make_unique<BearerTokenAuthenticationPolicy>(credential, std::move(tokenContext)));
std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>(
credential, std::move(tokenContext)));
}
std::vector<std::unique_ptr<HttpPolicy>> perCallpolicies;

View File

@ -14,7 +14,7 @@
{
"name": "azure-core-cpp",
"default-features": false,
"version>=": "1.7.2"
"version>=": "1.9.0-beta.1"
},
{
"name": "vcpkg-cmake",

View File

@ -4,6 +4,8 @@
### Features Added
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes
### Bugs Fixed

View File

@ -3,6 +3,7 @@
#include "azure/keyvault/certificates/certificate_client.hpp"
#include "azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp"
#include "azure/keyvault/shared/keyvault_shared.hpp"
#include "private/certificate_constants.hpp"
#include "private/certificate_serializers.hpp"
@ -76,7 +77,8 @@ CertificateClient::CertificateClient(
tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)};
perRetrypolicies.emplace_back(
std::make_unique<BearerTokenAuthenticationPolicy>(credential, std::move(tokenContext)));
std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>(
credential, std::move(tokenContext)));
}
std::vector<std::unique_ptr<HttpPolicy>> perCallpolicies;

View File

@ -14,7 +14,7 @@
{
"name": "azure-core-cpp",
"default-features": false,
"version>=": "1.5.0"
"version>=": "1.9.0-beta.1"
},
{
"name": "vcpkg-cmake",

View File

@ -4,6 +4,8 @@
### Features Added
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes
### Bugs Fixed

View File

@ -6,6 +6,7 @@
#include <azure/core/http/http.hpp>
#include <azure/core/http/policies/policy.hpp>
#include <azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp>
#include <azure/keyvault/shared/keyvault_shared.hpp>
#include "azure/keyvault/keys/cryptography/cryptography_client.hpp"
@ -105,7 +106,8 @@ CryptographyClient::CryptographyClient(
tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_keyId)};
perRetrypolicies.emplace_back(
std::make_unique<BearerTokenAuthenticationPolicy>(credential, tokenContext));
std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>(
credential, tokenContext));
}
std::vector<std::unique_ptr<HttpPolicy>> perCallpolicies;

View File

@ -5,6 +5,7 @@
#include <azure/core/http/http.hpp>
#include <azure/core/http/policies/policy.hpp>
#include <azure/core/internal/http/pipeline.hpp>
#include <azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp>
#include <azure/keyvault/shared/keyvault_shared.hpp>
#include "azure/keyvault/keys/key_client.hpp"
@ -75,7 +76,8 @@ KeyClient::KeyClient(
tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(m_vaultUrl)};
perRetrypolicies.emplace_back(
std::make_unique<BearerTokenAuthenticationPolicy>(credential, std::move(tokenContext)));
std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>(
credential, std::move(tokenContext)));
}
std::vector<std::unique_ptr<HttpPolicy>> perCallpolicies;

View File

@ -14,7 +14,7 @@
{
"name": "azure-core-cpp",
"default-features": false,
"version>=": "1.5.0"
"version>=": "1.9.0-beta.1"
},
{
"name": "vcpkg-cmake",

View File

@ -4,6 +4,8 @@
### Features Added
- Added support for challenge-based and multi-tenant authentication.
### Breaking Changes
### Bugs Fixed

View File

@ -17,6 +17,7 @@
#include <azure/core/credentials/credentials.hpp>
#include <azure/core/http/http.hpp>
#include <azure/core/http/policies/policy.hpp>
#include <azure/keyvault/shared/keyvault_challenge_based_authentication_policy.hpp>
#include <azure/keyvault/shared/keyvault_shared.hpp>
#include <algorithm>
@ -72,7 +73,8 @@ SecretClient::SecretClient(
tokenContext.Scopes = {_internal::UrlScope::GetScopeFromUrl(url)};
perRetrypolicies.emplace_back(
std::make_unique<BearerTokenAuthenticationPolicy>(credential, tokenContext));
std::make_unique<_internal::KeyVaultChallengeBasedAuthenticationPolicy>(
credential, tokenContext));
}
std::vector<std::unique_ptr<HttpPolicy>> perCallpolicies;

View File

@ -16,6 +16,7 @@ SetUpTestProxy("sdk/keyvault")
add_executable (
azure-security-keyvault-secrets-test
challenge_based_authentication_policy_test.cpp
macro_guard.cpp
secret_client_test.cpp
secret_get_client_deserialize_test.hpp
@ -38,7 +39,12 @@ endif()
target_link_libraries(azure-security-keyvault-secrets-test PRIVATE azure-security-keyvault-secrets azure-identity azure-core-test-fw gtest gtest_main gmock)
# Adding private headers so we can test the private APIs with no relative paths include.
target_include_directories (azure-security-keyvault-secrets-test PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../src>)
target_include_directories(
azure-security-keyvault-secrets-test
PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../src>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../../azure-security-keyvault-shared/inc>
)
# gtest_add_tests will scan the test from azure-core-test and call add_test
# for each test to ctest. This enables `ctest -r` to run specific tests directly.

View File

@ -14,7 +14,7 @@
{
"name": "azure-core-cpp",
"default-features": false,
"version>=": "1.5.0"
"version>=": "1.9.0-beta.1"
},
{
"name": "vcpkg-cmake",

View File

@ -0,0 +1,188 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Key Vault Challenge-Based Authentication Policy.
*
*/
#pragma once
#include <azure/core/http/policies/policy.hpp>
#include <azure/core/internal/credentials/authorization_challenge_parser.hpp>
#include <stdexcept>
namespace Azure { namespace Security { namespace KeyVault { namespace _internal {
/**
* @brief Challenge-Based Authentication Policy for Key Vault.
*
*/
class KeyVaultChallengeBasedAuthenticationPolicy final
: public Core::Http::Policies::_internal::BearerTokenAuthenticationPolicy {
private:
mutable Core::Credentials::TokenRequestContext m_tokenRequestContext;
public:
explicit KeyVaultChallengeBasedAuthenticationPolicy(
std::shared_ptr<Core::Credentials::TokenCredential const> credential,
Core::Credentials::TokenRequestContext tokenRequestContext)
: BearerTokenAuthenticationPolicy(credential, tokenRequestContext),
m_tokenRequestContext(tokenRequestContext)
{
}
std::unique_ptr<HttpPolicy> Clone() const override
{
return std::make_unique<KeyVaultChallengeBasedAuthenticationPolicy>(*this);
}
private:
std::unique_ptr<Core::Http::RawResponse> AuthorizeAndSendRequest(
Core::Http::Request& request,
Core::Http::Policies::NextHttpPolicy& nextPolicy,
Core::Context const& context) const override
{
AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context);
return nextPolicy.Send(request, context);
}
bool AuthorizeRequestOnChallenge(
std::string const& challenge,
Core::Http::Request& request,
Core::Context const& context) const override
{
auto const scope = GetScope(challenge);
if (scope.empty())
{
return false;
}
ValidateChallengeResponse(scope, request.GetUrl().GetHost());
auto const tenantId = GetTenantId(GetAuthorization(challenge));
m_tokenRequestContext.TenantId = tenantId;
m_tokenRequestContext.Scopes = {scope};
AuthenticateAndAuthorizeRequest(request, m_tokenRequestContext, context);
return true;
}
static std::string TrimTrailingSlash(std::string const& s)
{
return (s.empty() || s.back() != '/') ? s : s.substr(0, s.size() - 1);
}
static std::string GetScope(std::string const& challenge)
{
using Core::Credentials::_internal::AuthorizationChallengeParser;
auto resource
= AuthorizationChallengeParser::GetChallengeParameter(challenge, "Bearer", "resource");
return !resource.empty()
? (TrimTrailingSlash(resource) + "/.default")
: AuthorizationChallengeParser::GetChallengeParameter(challenge, "Bearer", "scope");
}
static std::string GetAuthorization(std::string const& challenge)
{
using Core::Credentials::_internal::AuthorizationChallengeParser;
auto authorization = AuthorizationChallengeParser::GetChallengeParameter(
challenge, "Bearer", "authorization");
return !authorization.empty() ? authorization
: AuthorizationChallengeParser::GetChallengeParameter(
challenge, "Bearer", "authorization_uri");
}
static bool TryParseUrl(std::string const& s, Core::Url& outUrl)
{
using Core::Url;
try
{
outUrl = Url(s);
}
catch (std::out_of_range const&)
{
return false;
}
catch (std::invalid_argument const&)
{
return false;
}
return true;
}
static void ValidateChallengeResponse(std::string const& scope, std::string const& requestHost)
{
using Core::Url;
using Core::Credentials::AuthenticationException;
Url scopeUrl;
if (!TryParseUrl(scope, scopeUrl))
{
throw AuthenticationException("The challenge contains invalid scope '" + scope + "'.");
}
auto const& scopeHost = scopeUrl.GetHost();
// Check whether requestHost.ends_with(scopeHost)
auto const requestHostLength = requestHost.length();
auto const scopeHostLength = scopeHost.length();
bool domainMismatch = requestHostLength < scopeHostLength;
if (!domainMismatch)
{
auto const requestHostOffset = requestHostLength - scopeHostLength;
for (size_t i = 0; i < scopeHostLength; ++i)
{
if (requestHost[requestHostOffset + i] != scopeHost[i])
{
domainMismatch = true;
break;
}
}
}
if (domainMismatch)
{
throw AuthenticationException(
"The challenge resource '" + scopeHost + "' does not match the requested domain.");
}
}
static std::string GetTenantId(std::string const& authorization)
{
using Core::Url;
using Core::Credentials::AuthenticationException;
if (!authorization.empty())
{
Url authorizationUrl;
if (TryParseUrl(authorization, authorizationUrl))
{
auto const& path = authorizationUrl.GetPath();
if (!path.empty())
{
auto const firstSlash = path.find('/');
if (firstSlash == std::string::npos)
{
return path;
}
else if (firstSlash > 0)
{
return path.substr(0, firstSlash);
}
}
}
}
throw AuthenticationException(
"The challenge authorization URI '" + authorization + "' is invalid.");
}
};
}}}} // namespace Azure::Security::KeyVault::_internal