Add ADFS support for ClientSecretCredential (#1947)

This commit is contained in:
Anton Kolesnyk 2021-03-30 00:26:42 +00:00 committed by GitHub
parent b4110380f0
commit b606ff60dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 559 additions and 29 deletions

View File

@ -46,6 +46,7 @@ set(
${CURL_TRANSPORT_ADAPTER_INC}
${WIN_TRANSPORT_ADAPTER_INC}
inc/azure/core/credentials/credentials.hpp
inc/azure/core/credentials/token_credential_options.hpp
inc/azure/core/cryptography/hash.hpp
inc/azure/core/diagnostics/logger.hpp
inc/azure/core/http/http.hpp

View File

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Options for #Azure::Core::Credentials::TokenCredential.
*/
#pragma once
#include "azure/core/internal/client_options.hpp"
namespace Azure { namespace Core { namespace Credentials {
/**
* @brief Defines options for #Azure::Core::Credentials::TokenCredential.
*/
struct TokenCredentialOptions : public Azure::Core::_internal::ClientOptions
{
};
}}} // namespace Azure::Core::Credentials

View File

@ -40,6 +40,20 @@ DateTime GetSystemClockEpoch()
static_cast<int8_t>(systemClockEpochUtcStructTm->tm_sec));
}
DateTime GetMaxDateTime()
{
auto const systemClockMax = std::chrono::duration_cast<DateTime::clock::duration>(
std::chrono::system_clock::time_point::max().time_since_epoch())
.count();
auto const systemClockEpoch = GetSystemClockEpoch().time_since_epoch().count();
constexpr auto repMax = std::numeric_limits<DateTime::clock::duration::rep>::max();
return DateTime(DateTime::time_point(
DateTime::duration(systemClockMax + std::min(systemClockEpoch, (repMax - systemClockMax)))));
}
template <typename T>
void ValidateDateElementRange(
T value,
@ -409,7 +423,7 @@ DateTime::DateTime(
DateTime::operator std::chrono::system_clock::time_point() const
{
static DateTime SystemClockMin(std::chrono::system_clock::time_point::min());
static DateTime SystemClockMax(std::chrono::system_clock::time_point::max());
static DateTime SystemClockMax(GetMaxDateTime());
auto outOfRange = 0;
if (*this < SystemClockMin)

View File

@ -11,8 +11,8 @@
#include "azure/identity/dll_import_export.hpp"
#include <azure/core/credentials/credentials.hpp>
#include <azure/core/credentials/token_credential_options.hpp>
#include <azure/core/http/policies/policy.hpp>
#include <azure/core/internal/client_options.hpp>
#include <string>
#include <utility>
@ -25,7 +25,7 @@ namespace Azure { namespace Identity {
/**
* @brief Defines options for token authentication.
*/
struct ClientSecretCredentialOptions : public Azure::Core::_internal::ClientOptions
struct ClientSecretCredentialOptions : public Azure::Core::Credentials::TokenCredentialOptions
{
public:
/**
@ -64,12 +64,32 @@ namespace Azure { namespace Identity {
std::string tenantId,
std::string clientId,
std::string clientSecret,
ClientSecretCredentialOptions options = ClientSecretCredentialOptions())
ClientSecretCredentialOptions options)
: m_tenantId(std::move(tenantId)), m_clientId(std::move(clientId)),
m_clientSecret(std::move(clientSecret)), m_options(std::move(options))
{
}
/**
* @brief Construct a Client Secret credential.
*
* @param tenantId Tenant ID.
* @param clientId Client ID.
* @param clientSecret Client Secret.
* @param options #Azure::Core::Credentials::TokenCredentialOptions.
*/
explicit ClientSecretCredential(
std::string tenantId,
std::string clientId,
std::string clientSecret,
Azure::Core::Credentials::TokenCredentialOptions const& options
= Azure::Core::Credentials::TokenCredentialOptions())
: m_tenantId(std::move(tenantId)), m_clientId(std::move(clientId)),
m_clientSecret(std::move(clientSecret))
{
static_cast<Azure::Core::Credentials::TokenCredentialOptions&>(m_options) = options;
}
Core::Credentials::AccessToken GetToken(
Core::Credentials::TokenRequestContext const& tokenRequestContext,
Core::Context const& context) const override;

View File

@ -9,11 +9,11 @@
#pragma once
#include <azure/core/credentials/credentials.hpp>
#include <azure/core/credentials/token_credential_options.hpp>
#include <memory>
namespace Azure { namespace Identity {
/**
* @brief An environment credential.
*/
@ -32,7 +32,9 @@ namespace Azure { namespace Identity {
* - AZURE_USERNAME
* - AZURE_PASSWORD
*/
explicit EnvironmentCredential();
explicit EnvironmentCredential(
Azure::Core::Credentials::TokenCredentialOptions options
= Azure::Core::Credentials::TokenCredentialOptions());
Core::Credentials::AccessToken GetToken(
Core::Credentials::TokenRequestContext const& tokenRequestContext,

View File

@ -9,6 +9,40 @@
#include <chrono>
#include <sstream>
namespace {
// Assumes !scopes.empty()
std::string FormatScopes(std::vector<std::string> const& scopes, bool asResource)
{
if (asResource && scopes.size() == 1)
{
auto resource = scopes[0];
constexpr char suffix[] = "/.default";
constexpr int suffixLen = sizeof(suffix) - 1;
auto const resourceLen = resource.length();
// If scopes[0] ends with '/.default', remove it.
if (resourceLen >= suffixLen
&& resource.find(suffix, resourceLen - suffixLen) != std::string::npos)
{
resource = resource.substr(0, resourceLen - suffixLen);
}
return Azure::Core::Url::Encode(resource);
}
auto scopesIter = scopes.begin();
auto scopesStr = Azure::Core::Url::Encode(*scopesIter);
auto const scopesEnd = scopes.end();
for (++scopesIter; scopesIter != scopesEnd; ++scopesIter)
{
scopesStr += std::string(" ") + Azure::Core::Url::Encode(*scopesIter);
}
return scopesStr;
}
} // namespace
using namespace Azure::Identity;
std::string const Azure::Identity::_detail::g_aadGlobalAuthority
@ -27,24 +61,21 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken(
static std::string const errorMsgPrefix("ClientSecretCredential::GetToken: ");
try
{
auto const isAdfs = m_tenantId == "adfs";
Url url(m_options.AuthorityHost);
url.AppendPath(m_tenantId);
url.AppendPath("oauth2/v2.0/token");
url.AppendPath(isAdfs ? "oauth2/token" : "oauth2/v2.0/token");
std::ostringstream body;
body << "grant_type=client_credentials&client_id=" << Url::Encode(m_clientId)
<< "&client_secret=" << Url::Encode(m_clientSecret);
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
auto scopesIter = scopes.begin();
body << "&scope=" << Url::Encode(*scopesIter);
auto const scopesEnd = scopes.end();
for (++scopesIter; scopesIter != scopesEnd; ++scopesIter)
auto const& scopes = tokenRequestContext.Scopes;
if (!scopes.empty())
{
body << " " << *scopesIter;
body << "&scope=" << FormatScopes(scopes, isAdfs);
}
}
@ -58,6 +89,11 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken(
request.SetHeader("Content-Type", "application/x-www-form-urlencoded");
request.SetHeader("Content-Length", std::to_string(bodyString.size()));
if (isAdfs)
{
request.SetHeader("Host", url.GetHost());
}
HttpPipeline httpPipeline(m_options, "Identity-client-secret-credential", "", {}, {});
std::shared_ptr<RawResponse> response = httpPipeline.Send(request, context);

View File

@ -21,7 +21,8 @@
using namespace Azure::Identity;
EnvironmentCredential::EnvironmentCredential()
EnvironmentCredential::EnvironmentCredential(
Azure::Core::Credentials::TokenCredentialOptions options)
{
#if !defined(WINAPI_PARTITION_DESKTOP) \
|| WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation.
@ -53,28 +54,29 @@ EnvironmentCredential::EnvironmentCredential()
{
if (authority != nullptr)
{
ClientSecretCredentialOptions options;
options.AuthorityHost = authority;
ClientSecretCredentialOptions clientSecretCredentialOptions;
static_cast<Core::_internal::ClientOptions&>(clientSecretCredentialOptions) = options;
clientSecretCredentialOptions.AuthorityHost = authority;
m_credentialImpl.reset(
new ClientSecretCredential(tenantId, clientId, clientSecret, options));
m_credentialImpl.reset(new ClientSecretCredential(
tenantId, clientId, clientSecret, clientSecretCredentialOptions));
}
else
{
m_credentialImpl.reset(new ClientSecretCredential(tenantId, clientId, clientSecret));
m_credentialImpl.reset(
new ClientSecretCredential(tenantId, clientId, clientSecret, options));
}
}
// TODO: These credential types are not implemented. Uncomment when implemented.
// else if (username != nullptr && password != nullptr)
//{
// m_credentialImpl.reset(
// new UsernamePasswordCredential(username, password, tenantId, clientId));
//}
// {
// m_credentialImpl.reset(
// new UsernamePasswordCredential(tenantId, clientId, username, password, options));
// }
// else if (clientCertificatePath != nullptr)
//{
// m_credentialImpl.reset(
// new ClientCertificateCredential(tenantId, clientId, clientCertificatePath));
//}
// {
// m_credentialImpl.reset(new ClientCertificateCredential(tenantId, clientId, options));
// }
}
#endif
}

View File

@ -13,9 +13,12 @@ include(GoogleTest)
add_executable (
azure-identity-test
client_secret_credential.cpp
environment_credential.cpp
macro_guard.cpp
main.cpp
simplified_header.cpp
test_transport.hpp
)
if (MSVC)

View File

@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/identity/client_secret_credential.hpp"
#include <azure/core/io/body_stream.hpp>
#include "test_transport.hpp"
#include <gtest/gtest.h>
using namespace Azure::Identity;
namespace {
struct CredentialResult
{
struct RequestInfo
{
std::string AbsoluteUrl;
Azure::Core::CaseInsensitiveMap Headers;
std::string Body;
} Request;
struct
{
std::chrono::system_clock::time_point Earliest;
std::chrono::system_clock::time_point Latest;
Azure::Core::Credentials::AccessToken AccessToken;
} Response;
};
CredentialResult TestClientSecretCredential(
std::string const& tenantId,
std::string const& clientId,
std::string const& clientSecret,
ClientSecretCredentialOptions credentialOptions,
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
std::string const& responseBody)
{
CredentialResult result;
auto responseVec = std::vector<uint8_t>(responseBody.begin(), responseBody.end());
credentialOptions.Transport.Transport = std::make_shared<TestTransport>([&](auto request, auto) {
auto const bodyVec = request.GetBodyStream()->ReadToEnd(Azure::Core::Context());
result.Request
= {request.GetUrl().GetAbsoluteUrl(),
request.GetHeaders(),
std::string(bodyVec.begin(), bodyVec.end())};
auto response = std::make_unique<Azure::Core::Http::RawResponse>(
1, 1, Azure::Core::Http::HttpStatusCode::Ok, "OK");
response->SetBodyStream(std::make_unique<Azure::Core::IO::MemoryBodyStream>(responseVec));
result.Response.Earliest = std::chrono::system_clock::now();
return response;
});
ClientSecretCredential credential(tenantId, clientId, clientSecret, credentialOptions);
result.Response.AccessToken = credential.GetToken(tokenRequestContext, Azure::Core::Context());
result.Response.Latest = std::chrono::system_clock::now();
return result;
}
} // namespace
TEST(ClientSecretCredential, Regular)
{
ClientSecretCredentialOptions options;
options.AuthorityHost = "https://microsoft.com/";
auto const actual = TestClientSecretCredential(
"01234567-89ab-cdef-fedc-ba8976543210",
"fedcba98-7654-3210-0123-456789abcdef",
"CLIENTSECRET",
options,
{{"https://azure.com/.default"}},
"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}");
EXPECT_EQ(
actual.Request.AbsoluteUrl,
"https://microsoft.com/01234567-89ab-cdef-fedc-ba8976543210/oauth2/v2.0/token");
{
constexpr char expectedBody[] = "grant_type=client_credentials"
"&client_id=fedcba98-7654-3210-0123-456789abcdef"
"&client_secret=CLIENTSECRET"
"&scope=https%3A%2F%2Fazure.com%2F.default";
EXPECT_EQ(actual.Request.Body, expectedBody);
EXPECT_NE(actual.Request.Headers.find("Content-Length"), actual.Request.Headers.end());
EXPECT_EQ(
actual.Request.Headers.at("Content-Length"), std::to_string(sizeof(expectedBody) - 1));
}
EXPECT_NE(actual.Request.Headers.find("Content-Type"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Content-Type"), "application/x-www-form-urlencoded");
EXPECT_EQ(actual.Response.AccessToken.Token, "ACCESSTOKEN1");
using namespace std::chrono_literals;
EXPECT_GT(actual.Response.AccessToken.ExpiresOn, actual.Response.Earliest + 3600s);
EXPECT_LT(actual.Response.AccessToken.ExpiresOn, actual.Response.Latest + 3600s);
}
TEST(ClientSecretCredential, AzureStack)
{
ClientSecretCredentialOptions options;
options.AuthorityHost = "https://microsoft.com/";
auto const actual = TestClientSecretCredential(
"adfs",
"fedcba98-7654-3210-0123-456789abcdef",
"CLIENTSECRET",
options,
{{"https://azure.com/.default"}},
"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}");
EXPECT_EQ(actual.Request.AbsoluteUrl, "https://microsoft.com/adfs/oauth2/token");
{
constexpr char expectedBody[] = "grant_type=client_credentials"
"&client_id=fedcba98-7654-3210-0123-456789abcdef"
"&client_secret=CLIENTSECRET"
"&scope=https%3A%2F%2Fazure.com";
EXPECT_EQ(actual.Request.Body, expectedBody);
EXPECT_NE(actual.Request.Headers.find("Content-Length"), actual.Request.Headers.end());
EXPECT_EQ(
actual.Request.Headers.at("Content-Length"), std::to_string(sizeof(expectedBody) - 1));
}
EXPECT_NE(actual.Request.Headers.find("Content-Type"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Content-Type"), "application/x-www-form-urlencoded");
EXPECT_NE(actual.Request.Headers.find("Host"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Host"), "microsoft.com");
EXPECT_EQ(actual.Response.AccessToken.Token, "ACCESSTOKEN1");
using namespace std::chrono_literals;
EXPECT_GT(actual.Response.AccessToken.ExpiresOn, actual.Response.Earliest + 3600s);
EXPECT_LT(actual.Response.AccessToken.ExpiresOn, actual.Response.Latest + 3600s);
}

View File

@ -0,0 +1,257 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/identity/environment_credential.hpp"
#include <azure/core/io/body_stream.hpp>
#include <azure/core/platform.hpp>
#include "test_transport.hpp"
#include <gtest/gtest.h>
#include <stdlib.h>
#if defined(AZ_PLATFORM_WINDOWS)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#if !defined(NOMINMAX)
#define NOMINMAX
#endif
#include <windows.h>
#endif
#if !defined(WINAPI_PARTITION_DESKTOP) \
|| WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation.
using namespace Azure::Identity;
namespace {
class EnvironmentOverride {
class Environment {
static void SetVariable(std::string const& name, std::string const& value)
{
#if defined(_MSC_VER)
static_cast<void>(_putenv((name + "=" + value).c_str()));
#else
if (value.empty())
{
static_cast<void>(unsetenv(name.c_str()));
}
else
{
static_cast<void>(setenv(name.c_str(), value.c_str(), 1));
}
#endif
}
public:
static std::string GetVariable(std::string const& name)
{
#if defined(_MSC_VER)
#pragma warning(push)
// warning C4996: 'getenv': This function or variable may be unsafe. Consider using _dupenv_s
// instead.
#pragma warning(disable : 4996)
#endif
auto const result = std::getenv(name.c_str());
return result != nullptr ? result : "";
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
}
static void SetVariables(std::map<std::string, std::string> const& vars)
{
for (auto var : vars)
{
SetVariable(var.first, var.second);
}
}
};
std::map<std::string, std::string> m_originalEnv;
public:
~EnvironmentOverride() { Environment::SetVariables(m_originalEnv); }
EnvironmentOverride(
std::string const& tenantId,
std::string const& clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::string const& username,
std::string const& password,
std::string const& clientCertificatePath)
{
std::map<std::string, std::string> const NewEnv = {
{"AZURE_TENANT_ID", tenantId},
{"AZURE_CLIENT_ID", clientId},
{"AZURE_CLIENT_SECRET", clientSecret},
{"AZURE_AUTHORITY_HOST", authorityHost},
{"AZURE_USERNAME", username},
{"AZURE_PASSWORD", password},
{"AZURE_CLIENT_CERTIFICATE_PATH", clientCertificatePath},
};
for (auto var : NewEnv)
{
m_originalEnv[var.first] = Environment::GetVariable(var.first);
}
try
{
Environment::SetVariables(NewEnv);
}
catch (...)
{
Environment::SetVariables(m_originalEnv);
throw;
}
}
};
struct CredentialResult
{
struct
{
std::string AbsoluteUrl;
Azure::Core::CaseInsensitiveMap Headers;
std::string Body;
} Request;
struct
{
std::chrono::system_clock::time_point Earliest;
std::chrono::system_clock::time_point Latest;
Azure::Core::Credentials::AccessToken AccessToken;
} Response;
};
CredentialResult TestEnvironmentCredential(
std::string const& tenantId,
std::string const& clientId,
std::string const& clientSecret,
std::string const& authorityHost,
std::string const& username,
std::string const& password,
std::string const& clientCertificatePath,
Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext,
std::string const& responseBody)
{
CredentialResult result;
auto responseVec = std::vector<uint8_t>(responseBody.begin(), responseBody.end());
Azure::Core::Credentials::TokenCredentialOptions credentialOptions;
credentialOptions.Transport.Transport = std::make_shared<TestTransport>([&](auto request, auto) {
auto const bodyVec = request.GetBodyStream()->ReadToEnd(Azure::Core::Context());
result.Request
= {request.GetUrl().GetAbsoluteUrl(),
request.GetHeaders(),
std::string(bodyVec.begin(), bodyVec.end())};
auto response = std::make_unique<Azure::Core::Http::RawResponse>(
1, 1, Azure::Core::Http::HttpStatusCode::Ok, "OK");
response->SetBodyStream(std::make_unique<Azure::Core::IO::MemoryBodyStream>(responseVec));
result.Response.Earliest = std::chrono::system_clock::now();
return response;
});
EnvironmentOverride env(
tenantId, clientId, clientSecret, authorityHost, username, password, clientCertificatePath);
EnvironmentCredential credential(credentialOptions);
result.Response.AccessToken = credential.GetToken(tokenRequestContext, Azure::Core::Context());
result.Response.Latest = std::chrono::system_clock::now();
return result;
}
} // namespace
TEST(EnvironmentCredential, RegularClientSecretCredential)
{
auto const actual = TestEnvironmentCredential(
"01234567-89ab-cdef-fedc-ba8976543210",
"fedcba98-7654-3210-0123-456789abcdef",
"CLIENTSECRET",
"https://microsoft.com/",
"",
"",
"",
{{"https://azure.com/.default"}},
"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}");
EXPECT_EQ(
actual.Request.AbsoluteUrl,
"https://microsoft.com/01234567-89ab-cdef-fedc-ba8976543210/oauth2/v2.0/token");
{
constexpr char expectedBody[] = "grant_type=client_credentials"
"&client_id=fedcba98-7654-3210-0123-456789abcdef"
"&client_secret=CLIENTSECRET"
"&scope=https%3A%2F%2Fazure.com%2F.default";
EXPECT_EQ(actual.Request.Body, expectedBody);
EXPECT_NE(actual.Request.Headers.find("Content-Length"), actual.Request.Headers.end());
EXPECT_EQ(
actual.Request.Headers.at("Content-Length"), std::to_string(sizeof(expectedBody) - 1));
}
EXPECT_NE(actual.Request.Headers.find("Content-Type"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Content-Type"), "application/x-www-form-urlencoded");
EXPECT_EQ(actual.Response.AccessToken.Token, "ACCESSTOKEN1");
using namespace std::chrono_literals;
EXPECT_GT(actual.Response.AccessToken.ExpiresOn, actual.Response.Earliest + 3600s);
EXPECT_LT(actual.Response.AccessToken.ExpiresOn, actual.Response.Latest + 3600s);
}
TEST(EnvironmentCredential, AzureStackClientSecretCredential)
{
auto const actual = TestEnvironmentCredential(
"adfs",
"fedcba98-7654-3210-0123-456789abcdef",
"CLIENTSECRET",
"https://microsoft.com/",
"",
"",
"",
{{"https://azure.com/.default"}},
"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN1\"}");
EXPECT_EQ(actual.Request.AbsoluteUrl, "https://microsoft.com/adfs/oauth2/token");
{
constexpr char expectedBody[] = "grant_type=client_credentials"
"&client_id=fedcba98-7654-3210-0123-456789abcdef"
"&client_secret=CLIENTSECRET"
"&scope=https%3A%2F%2Fazure.com";
EXPECT_EQ(actual.Request.Body, expectedBody);
EXPECT_NE(actual.Request.Headers.find("Content-Length"), actual.Request.Headers.end());
EXPECT_EQ(
actual.Request.Headers.at("Content-Length"), std::to_string(sizeof(expectedBody) - 1));
}
EXPECT_NE(actual.Request.Headers.find("Content-Type"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Content-Type"), "application/x-www-form-urlencoded");
EXPECT_NE(actual.Request.Headers.find("Host"), actual.Request.Headers.end());
EXPECT_EQ(actual.Request.Headers.at("Host"), "microsoft.com");
EXPECT_EQ(actual.Response.AccessToken.Token, "ACCESSTOKEN1");
using namespace std::chrono_literals;
EXPECT_GT(actual.Response.AccessToken.ExpiresOn, actual.Response.Earliest + 3600s);
EXPECT_LT(actual.Response.AccessToken.ExpiresOn, actual.Response.Latest + 3600s);
}
#endif

View File

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#pragma once
#include <functional>
#include <azure/core/http/transport.hpp>
class TestTransport : public Azure::Core::Http::HttpTransport {
public:
typedef std::function<std::unique_ptr<Azure::Core::Http::RawResponse>(
Azure::Core::Http::Request& request,
Azure::Core::Context const& context)>
SendCallback;
private:
SendCallback m_sendCallback;
public:
TestTransport(SendCallback send) : m_sendCallback(send) {}
std::unique_ptr<Azure::Core::Http::RawResponse> Send(
Azure::Core::Http::Request& request,
Azure::Core::Context const& context) override
{
return m_sendCallback(request, context);
}
};