From 09db139a71d5e4ef2838f3b6aa886ff74a2f0013 Mon Sep 17 00:00:00 2001 From: Anton Kolesnyk <41349689+antkmsft@users.noreply.github.com> Date: Mon, 5 Dec 2022 13:15:19 -0800 Subject: [PATCH] AzureCliCredential (#4146) * AzureCliCredential Co-authored-by: Anton Kolesnyk --- .../inc/azure/core/internal/unique_handle.hpp | 3 +- sdk/identity/azure-identity/CHANGELOG.md | 2 + sdk/identity/azure-identity/CMakeLists.txt | 2 + .../azure-identity/inc/azure/identity.hpp | 1 + .../azure/identity/azure_cli_credential.hpp | 94 +++ .../identity/client_secret_credential.hpp | 1 - .../azure/identity/environment_credential.hpp | 1 + .../azure-identity/samples/CMakeLists.txt | 5 + .../samples/azure_cli_credential.cpp | 34 + .../src/azure_cli_credential.cpp | 644 ++++++++++++++++++ .../src/client_secret_credential.cpp | 5 +- .../src/private/token_cache_internals.hpp | 8 +- .../src/private/token_credential_impl.hpp | 33 +- .../src/token_credential_impl.cpp | 325 ++++++--- .../azure-identity/test/ut/CMakeLists.txt | 1 + .../test/ut/azure_cli_credential_test.cpp | 291 ++++++++ .../test/ut/token_credential_impl_test.cpp | 3 +- 17 files changed, 1333 insertions(+), 120 deletions(-) create mode 100644 sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp create mode 100644 sdk/identity/azure-identity/samples/azure_cli_credential.cpp create mode 100644 sdk/identity/azure-identity/src/azure_cli_credential.cpp create mode 100644 sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp diff --git a/sdk/core/azure-core/inc/azure/core/internal/unique_handle.hpp b/sdk/core/azure-core/inc/azure/core/internal/unique_handle.hpp index 10a69b715..de9b77584 100644 --- a/sdk/core/azure-core/inc/azure/core/internal/unique_handle.hpp +++ b/sdk/core/azure-core/inc/azure/core/internal/unique_handle.hpp @@ -62,5 +62,6 @@ namespace Azure { namespace Core { namespace _internal { template struct UniqueHandleHelper; // *** Now users can say UniqueHandle if they want: - template using UniqueHandle = typename UniqueHandleHelper::type; + template class U = UniqueHandleHelper> + using UniqueHandle = typename U::type; }}} // namespace Azure::Core::_internal diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 8048ced0f..1eb114046 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added Azure CLI Credential. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index 6b686d633..6232727af 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -46,6 +46,7 @@ endif() set( AZURE_IDENTITY_HEADER + inc/azure/identity/azure_cli_credential.hpp inc/azure/identity/chained_token_credential.hpp inc/azure/identity/client_certificate_credential.hpp inc/azure/identity/client_secret_credential.hpp @@ -63,6 +64,7 @@ set( src/private/token_cache.hpp src/private/token_cache_internals.hpp src/private/token_credential_impl.hpp + src/azure_cli_credential.cpp src/chained_token_credential.cpp src/client_certificate_credential.cpp src/client_secret_credential.cpp diff --git a/sdk/identity/azure-identity/inc/azure/identity.hpp b/sdk/identity/azure-identity/inc/azure/identity.hpp index ff96cd03f..6afce616a 100644 --- a/sdk/identity/azure-identity/inc/azure/identity.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity.hpp @@ -8,6 +8,7 @@ #pragma once +#include "azure/identity/azure_cli_credential.hpp" #include "azure/identity/chained_token_credential.hpp" #include "azure/identity/client_certificate_credential.hpp" #include "azure/identity/client_secret_credential.hpp" diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp new file mode 100644 index 000000000..910aa7fa7 --- /dev/null +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Azure CLI Credential uses Azure CLI to obtain an access token. + */ + +#pragma once + +#include +#include + +#include + +#include +#include + +namespace Azure { namespace Identity { + /** + * @brief Options for configuring the #Azure::Identity::AzureCliCredential. + */ + struct AzureCliCredentialOptions final : public Core::Credentials::TokenCredentialOptions + { + /** + * @brief The ID of the tenant to which the credential will authenticate by default. If not + * specified, the credential will authenticate to any requested tenant, and will default to the + * tenant provided to the 'az login' command. + */ + std::string TenantId; + + /** + * @brief The CLI process timeout. + */ + DateTime::duration CliProcessTimeout + = std::chrono::seconds(13); // Value was taken from .NET SDK. + }; + + /** + * @brief Enables authentication to Azure Active Directory using Azure CLI to obtain an access + * token. + */ + class AzureCliCredential +#if !defined(TESTING_BUILD) + final +#endif + : public Core::Credentials::TokenCredential { + protected: + std::string m_tenantId; + DateTime::duration m_cliProcessTimeout; + + private: + explicit AzureCliCredential( + std::string tenantId, + DateTime::duration cliProcessTimeout, + Core::Credentials::TokenCredentialOptions const& options); + + public: + /** + * @brief Constructs an Azure CLI Credential. + * + * @param options Options for token retrieval. + */ + explicit AzureCliCredential(AzureCliCredentialOptions const& options = {}); + + /** + * @brief Constructs an Azure CLI Credential. + * + * @param options Options for token retrieval. + */ + explicit AzureCliCredential(Core::Credentials::TokenCredentialOptions const& options); + + /** + * @brief Gets an authentication token. + * + * @param tokenRequestContext A context to get the token in. + * @param context A context to control the request lifetime. + * + * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. + */ + Core::Credentials::AccessToken GetToken( + Core::Credentials::TokenRequestContext const& tokenRequestContext, + Core::Context const& context) const override; + +#if !defined(TESTING_BUILD) + private: +#else + protected: +#endif + virtual std::string GetAzCommand(std::string const& resource, std::string const& tenantId) + const; + }; + +}} // namespace Azure::Identity diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp index 206c77108..ed72b29cf 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp @@ -29,7 +29,6 @@ namespace Azure { namespace Identity { */ struct ClientSecretCredentialOptions final : public Core::Credentials::TokenCredentialOptions { - public: /** * @brief Authentication authority URL. * @note Default value is Azure AD global authority (https://login.microsoftonline.com/). diff --git a/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp index fee63892a..7f0b1e143 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/environment_credential.hpp @@ -20,6 +20,7 @@ namespace Azure { namespace Identity { * */ class EnvironmentCredential final : public Core::Credentials::TokenCredential { + private: std::unique_ptr m_credentialImpl; public: diff --git a/sdk/identity/azure-identity/samples/CMakeLists.txt b/sdk/identity/azure-identity/samples/CMakeLists.txt index 06a1cb0cb..1e56cd640 100644 --- a/sdk/identity/azure-identity/samples/CMakeLists.txt +++ b/sdk/identity/azure-identity/samples/CMakeLists.txt @@ -7,6 +7,11 @@ project (azure-identity-samples LANGUAGES CXX) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED True) +add_executable(azure_cli_credential_sample azure_cli_credential.cpp) +target_link_libraries(azure_cli_credential_sample PRIVATE azure-identity service) +target_include_directories(azure_cli_credential_sample PRIVATE .) +create_per_service_target_build_for_sample(identity azure_cli_credential_sample) + add_executable(chained_token_credential_sample chained_token_credential.cpp) target_link_libraries(chained_token_credential_sample PRIVATE azure-identity service) target_include_directories(chained_token_credential_sample PRIVATE .) diff --git a/sdk/identity/azure-identity/samples/azure_cli_credential.cpp b/sdk/identity/azure-identity/samples/azure_cli_credential.cpp new file mode 100644 index 000000000..93c84f0c2 --- /dev/null +++ b/sdk/identity/azure-identity/samples/azure_cli_credential.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include + +#include + +#include + +int main() +{ + try + { + // Step 1: Initialize Azure CLI Credential. + auto azureCliCredential = std::make_shared(); + + // Step 2: Pass the credential to an Azure Service Client. + Azure::Service::Client azureServiceClient("serviceUrl", azureCliCredential); + + // Step 3: Start using the Azure Service Client. + azureServiceClient.DoSomething(Azure::Core::Context::ApplicationContext); + + std::cout << "Success!" << std::endl; + } + catch (const Azure::Core::Credentials::AuthenticationException& exception) + { + // Step 4: Handle authentication errors, if needed + // (Azure CLI invocation errors or process timeout). + std::cout << "Authentication error: " << exception.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/sdk/identity/azure-identity/src/azure_cli_credential.cpp b/sdk/identity/azure-identity/src/azure_cli_credential.cpp new file mode 100644 index 000000000..c4fe117b4 --- /dev/null +++ b/sdk/identity/azure-identity/src/azure_cli_credential.cpp @@ -0,0 +1,644 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/identity/azure_cli_credential.hpp" + +#include "private/token_credential_impl.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +using Azure::Identity::AzureCliCredential; + +using Azure::DateTime; +using Azure::Core::Context; +using Azure::Core::_internal::Environment; +using Azure::Core::Credentials::AccessToken; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredentialOptions; +using Azure::Core::Credentials::TokenRequestContext; +using Azure::Identity::AzureCliCredentialOptions; +using Azure::Identity::_detail::TokenCredentialImpl; + +namespace { +void ThrowIfNotSafeCmdLineInput(std::string const& input, std::string const& description) +{ + for (auto const c : input) + { + switch (c) + { + case ':': + case '/': + case '.': + case '-': + case '_': + case ' ': + break; + + default: + if (!std::isalnum(c)) + { + throw AuthenticationException( + "AzureCliCredential: Unsafe command line input found in " + description + ": " + + input); + } + } + } +} +} // namespace + +AzureCliCredential::AzureCliCredential( + std::string tenantId, + DateTime::duration cliProcessTimeout, + Core::Credentials::TokenCredentialOptions const& options) + : m_tenantId(std::move(tenantId)), m_cliProcessTimeout(std::move(cliProcessTimeout)) +{ + static_cast(options); + ThrowIfNotSafeCmdLineInput(m_tenantId, "TenantID"); +} + +AzureCliCredential::AzureCliCredential(AzureCliCredentialOptions const& options) + : AzureCliCredential(options.TenantId, options.CliProcessTimeout, options) +{ +} + +AzureCliCredential::AzureCliCredential(TokenCredentialOptions const& options) + : AzureCliCredential( + AzureCliCredentialOptions{}.TenantId, + AzureCliCredentialOptions{}.CliProcessTimeout, + options) +{ +} + +std::string AzureCliCredential::GetAzCommand( + std::string const& resource, + std::string const& tenantId) const +{ + ThrowIfNotSafeCmdLineInput(resource, "Resource"); + std::string command = "az account get-access-token --output json --resource \"" + resource + "\""; + + if (!tenantId.empty()) + { + command += " --tenant \"" + tenantId + "\""; + } + + return command; +} + +namespace { +std::string RunShellCommand( + std::string const& command, + DateTime::duration timeout, + Context const& context); +} + +AccessToken AzureCliCredential::GetToken( + TokenRequestContext const& tokenRequestContext, + Context const& context) const +{ + try + { + auto const azCliResult = RunShellCommand( + GetAzCommand( + TokenCredentialImpl::FormatScopes(tokenRequestContext.Scopes, true, false), m_tenantId), + m_cliProcessTimeout, + context); + + try + { + return TokenCredentialImpl::ParseToken(azCliResult, "accessToken", "expiresIn", "expiresOn"); + } + catch (std::exception const&) + { + // Throw the az command output (error message) + // limited to 250 characters (250 has no special meaning). + throw std::runtime_error(azCliResult.substr(0, 250)); + } + } + catch (std::exception const& e) + { + throw AuthenticationException(std::string("AzureCliCredential::GetToken(): ") + e.what()); + } +} + +namespace { +#if defined(AZ_PLATFORM_WINDOWS) +template struct UniqueHandleHelper; +template <> struct UniqueHandleHelper +{ + static void CloseWin32Handle(HANDLE handle) + { + if (handle != nullptr) + { + static_cast(CloseHandle(handle)); + } + } + + using type = Azure::Core::_internal::BasicUniqueHandle; +}; + +template +using UniqueHandle = Azure::Core::_internal::UniqueHandle; +#endif + +class ShellProcess; +class OutputPipe final { + friend class ShellProcess; + +private: +#if defined(AZ_PLATFORM_WINDOWS) + UniqueHandle m_writeHandle; + UniqueHandle m_readHandle; + OVERLAPPED m_overlapped = {}; +#else + std::vector m_fd; +#endif + + OutputPipe(OutputPipe const&) = delete; + OutputPipe& operator=(OutputPipe const&) = delete; + +public: + OutputPipe(); + + ~OutputPipe(); + + bool NonBlockingRead( + std::vector& buffer, + std::remove_reference::type::size_type& bytesRead, + bool& willHaveMoreData); +}; + +class ShellProcess final { +private: +#if defined(AZ_PLATFORM_WINDOWS) + UniqueHandle m_processHandle; +#else + std::vector m_argv; + std::vector m_argvValues; + + std::vector m_envp; + std::vector m_envpValues; + + posix_spawn_file_actions_t m_actions = {}; + pid_t m_pid = -1; +#endif + + ShellProcess(ShellProcess const&) = delete; + ShellProcess& operator=(ShellProcess const&) = delete; + + void Finalize(); + +public: + ShellProcess(std::string const& command, OutputPipe& outputPipe); + ~ShellProcess() { Finalize(); } + + void Terminate(); +}; + +std::string RunShellCommand( + std::string const& command, + DateTime::duration timeout, + Context const& context) +{ + // Use steady_clock so we're not affected by system time rewinding. + auto const terminateAfter = std::chrono::steady_clock::now() + + std::chrono::duration_cast(timeout); + + std::string output; + + OutputPipe pipe; + ShellProcess shellProcess(command, pipe); + + // Typically token json is just a bit less than 2KiB. + // The best buffer size is the one that lets us to read it in one go. + // (Should it be smaller, we will succeed as well, it'll just take more iterations). + std::vector processOutputBuf(2 * 1024); + + auto willHaveMoreData = true; + do + { + // Check if we should terminate + { + if (context.IsCancelled()) + { + shellProcess.Terminate(); + throw std::runtime_error("Context was cancelled before Azure CLI process was done."); + } + + if (std::chrono::steady_clock::now() > terminateAfter) + { + shellProcess.Terminate(); + throw std::runtime_error("Azure CLI process took too long to complete."); + } + } + + decltype(processOutputBuf)::size_type bytesRead = 0; + if (pipe.NonBlockingRead(processOutputBuf, bytesRead, willHaveMoreData)) + { + output.insert(output.size(), processOutputBuf.data(), bytesRead); + } + else if (willHaveMoreData) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // Value has no special meaning. + } + } while (willHaveMoreData); + + return output; +} + +#if defined(AZ_PLATFORM_WINDOWS) +void ThrowIfApiCallFails(BOOL apiResult, std::string const& errMsg) +{ + // LCOV_EXCL_START + if (!apiResult) + { + throw std::runtime_error( + errMsg + ": " + std::to_string(GetLastError()) + + ); + } + // LCOV_EXCL_STOP +} +#else +void ThrowIfApiCallFails(int apiResult, std::string const& errMsg) +{ + // LCOV_EXCL_START + if (apiResult != 0) + { + throw std::runtime_error( + errMsg + ": " + std::to_string(apiResult) + " (errno: " + std::to_string(errno) + ")"); + } + // LCOV_EXCL_STOP +} +#endif + +OutputPipe::OutputPipe() +{ +#if defined(AZ_PLATFORM_WINDOWS) + SECURITY_ATTRIBUTES pipeSecurity = {}; + pipeSecurity.nLength = sizeof(decltype(pipeSecurity)); + pipeSecurity.bInheritHandle = TRUE; + pipeSecurity.lpSecurityDescriptor = nullptr; + + { + HANDLE readHandle = nullptr; + HANDLE writeHandle = nullptr; + + ThrowIfApiCallFails( + CreatePipe(&readHandle, &writeHandle, &pipeSecurity, 0), "Cannot create output pipe"); + + m_readHandle.reset(readHandle); + m_writeHandle.reset(writeHandle); + } + + ThrowIfApiCallFails( + SetHandleInformation(m_readHandle.get(), HANDLE_FLAG_INHERIT, 0), + "Cannot ensure the read handle for the output pipe is not inherited"); +#else + m_fd.push_back(-1); + m_fd.push_back(-1); + + ThrowIfApiCallFails(pipe(m_fd.data()), "Cannot create output pipe"); + ThrowIfApiCallFails( + fcntl(m_fd[0], F_SETFL, O_NONBLOCK), "Cannot set up output pipe to have non-blocking read"); +#endif +} + +OutputPipe::~OutputPipe() +{ +#if !defined(AZ_PLATFORM_WINDOWS) + for (auto iter = m_fd.rbegin(); iter != m_fd.rend(); ++iter) + { + if (*iter != -1) + { + static_cast(close(*iter)); + } + } +#endif +} + +#if defined(AZ_PLATFORM_WINDOWS) +void AppendToEnvironmentValuesIfNotEmpty( + std::vector& environmentValues, + std::string const& envVarName, + std::string const& value) +{ + if (!value.empty()) // LCOV_EXCL_LINE + { + auto const envVarStatement = envVarName + "=" + value; + + environmentValues.insert( + environmentValues.end(), envVarStatement.begin(), envVarStatement.end()); + + environmentValues.push_back('\0'); // terminate the string + } +} + +void AppendToEnvironmentValuesIfDefined( + std::vector& environmentValues, + std::string const& envVarName) +{ + AppendToEnvironmentValuesIfNotEmpty( + environmentValues, envVarName, Environment::GetVariable(envVarName.c_str())); +} +#else +void AppendToArgvValues( + std::vector& argvValues, + std::vector::type::size_type>& argvValuePositions, + std::string const& value) +{ + argvValuePositions.push_back(argvValues.size()); + argvValues.insert(argvValues.end(), value.begin(), value.end()); + argvValues.push_back('\0'); +} + +void EnsureShellExists(std::string const& pathToShell) +{ + auto file = std::fopen(pathToShell.c_str(), "r"); + + // LCOV_EXCL_START + if (!file) + { + throw std::runtime_error("Cannot locate command line shell."); + } + // LCOV_EXCL_STOP + + std::fclose(file); +} +#endif + +ShellProcess::ShellProcess(std::string const& command, OutputPipe& outputPipe) +{ +#if defined(AZ_PLATFORM_WINDOWS) + // Start the process. + PROCESS_INFORMATION procInfo = {}; + + { + STARTUPINFO startupInfo = {}; + startupInfo.cb = sizeof(decltype(startupInfo)); + startupInfo.dwFlags |= STARTF_USESTDHANDLES; // cspell:disable-line + startupInfo.hStdInput = INVALID_HANDLE_VALUE; + startupInfo.hStdOutput = outputPipe.m_writeHandle.get(); + startupInfo.hStdError = outputPipe.m_writeHandle.get(); + + // Path to cmd.exe + std::vector commandLineStr; + { + auto const commandLine = "cmd /c " + command; + commandLineStr.insert(commandLineStr.end(), commandLine.begin(), commandLine.end()); + commandLineStr.push_back('\0'); + } + + // Form the environment + std::vector environmentValues; + LPVOID lpEnvironment = nullptr; + { + { + constexpr auto PathEnvVarName = "PATH"; + auto pathValue = Environment::GetVariable(PathEnvVarName); + + for (auto const pf : + {Environment::GetVariable("ProgramFiles"), + Environment::GetVariable("ProgramFiles(x86)")}) + { + if (!pf.empty()) // LCOV_EXCL_LINE + { + if (!pathValue.empty()) // LCOV_EXCL_LINE + { + pathValue += ";"; + } + + pathValue += pf + "\\Microsoft SDKs\\Azure\\CLI2\\wbin"; + } + } + + AppendToEnvironmentValuesIfNotEmpty(environmentValues, PathEnvVarName, pathValue); + } + + // Also provide SystemRoot variable. + // Without it, 'az' may fail with the following error: + // "Fatal Python error: _Py_HashRandomization_Init: failed to get random numbers to + // initialize Python + // Python runtime state: preinitialized + // ". + AppendToEnvironmentValuesIfDefined(environmentValues, "SystemRoot"); + + // Also provide USERPROFILE variable. + // Without it, we'll be getting "ERROR: Please run 'az login' to setup account." even if the + // user did log in. + AppendToEnvironmentValuesIfDefined(environmentValues, "USERPROFILE"); + + if (!environmentValues.empty()) // LCOV_EXCL_LINE + { + environmentValues.push_back('\0'); // terminate the block + lpEnvironment = environmentValues.data(); + } + } + + ThrowIfApiCallFails( + CreateProcessA( + nullptr, + commandLineStr.data(), + nullptr, + nullptr, + TRUE, + NORMAL_PRIORITY_CLASS | CREATE_NO_WINDOW, + lpEnvironment, + nullptr, + &startupInfo, + &procInfo), + "Cannot create process"); + } + + // We won't be needing the process main thread handle on our end. + static_cast(CloseHandle(procInfo.hThread)); + + // Keep the process handle so we can cancel it if it takes too long. + m_processHandle.reset(procInfo.hProcess); + + // We won't be writing to the pipe that is meant for the process. + // We will only be reading the pipe. + // So, now that the process is started, we can close write handle on our end. + outputPipe.m_writeHandle.reset(); +#else + // Form the 'argv' array: + // * An array of pointers to non-const C strings (0-terminated). + // * Last element is nullptr. + // * First element (at index 0) is path to a program. + { + // Since the strings that argv is pointing at do need to be non-const, + // and also because each commnd line argument needs to be a separate 0-terminated string, + // We do form all their values in the m_argvValues. + + // Since we append m_argvValues as we go, at one point after insertion it may reallocate the + // buffer to a different address in memory. For that reason, we can't grab addresses before we + // are done forming m_argvValues contents - so until that we record indices where each string + // start - in argvValuePositions. + { + std::vector argvValuePositions; + + // First string is the path to executable, and not the actual first argument. + { + std::string const Shell = "/bin/sh"; + EnsureShellExists(Shell); + AppendToArgvValues(m_argvValues, argvValuePositions, Shell); + } + + // Second argument is the shell switch that tells the command line shell to execute a command + AppendToArgvValues(m_argvValues, argvValuePositions, "-c"); + + // Third value is the command that needs to be executed. + AppendToArgvValues(m_argvValues, argvValuePositions, command); + + // We are done appending to m_argvValues, so it is now safe to grab addresses to the elements + // in it. + for (auto const pos : argvValuePositions) + { + m_argv.push_back(m_argvValues.data() + pos); + } + } + + // argv last element needs to be nullptr. + m_argv.push_back(nullptr); + } + + // Form the 'envp' array: + // * An array of pointers to non-const C strings (0-terminated). + // * Strings are in form key=value (PATH uses ':' as separator) + // * Last element is nullptr. + // * First element (at index 0) is path to a program. + { + auto const actualPathVarValue = Environment::GetVariable("PATH"); + auto const processPathVarStatement = std::string("PATH=") + actualPathVarValue + + (actualPathVarValue.empty() ? "" : ":") + "/usr/bin:/usr/local/bin"; + + m_envpValues.insert( + m_envpValues.end(), processPathVarStatement.begin(), processPathVarStatement.end()); + + m_envpValues.push_back('\0'); + + // We should only grab m_envpValues.data() as we're done appending to it, because appends may + // reallocate the buffer to a different memory location. + m_envp.push_back(m_envpValues.data()); + m_envp.push_back(nullptr); + } + + // Set up pipe communication for the process. + static_cast(posix_spawn_file_actions_init(&m_actions)); + static_cast(posix_spawn_file_actions_addclose(&m_actions, outputPipe.m_fd[0])); + static_cast(posix_spawn_file_actions_adddup2(&m_actions, outputPipe.m_fd[1], 1)); + static_cast(posix_spawn_file_actions_addclose(&m_actions, outputPipe.m_fd[1])); + + { + auto const spawnResult + = posix_spawn(&m_pid, m_argv[0], &m_actions, NULL, m_argv.data(), m_envp.data()); + + // LCOV_EXCL_START + if (spawnResult != 0) + { + m_pid = -1; + Finalize(); + ThrowIfApiCallFails(spawnResult, "Cannot spawn process"); + } + // LCOV_EXCL_STOP + } + + close(outputPipe.m_fd[1]); + outputPipe.m_fd[1] = -1; +#endif +} + +void ShellProcess::Finalize() +{ +#if !defined(AZ_PLATFORM_WINDOWS) + if (m_pid > 0) + { + static_cast(waitpid(m_pid, nullptr, 0)); + } + + posix_spawn_file_actions_destroy(&m_actions); +#endif +} + +void ShellProcess::Terminate() +{ +#if defined(AZ_PLATFORM_WINDOWS) + static_cast(TerminateProcess(m_processHandle.get(), 0)); +#else + if (m_pid > 0) + { + static_cast(kill(m_pid, SIGKILL)); + } +#endif +} + +bool OutputPipe::NonBlockingRead( + std::vector& buffer, + std::remove_reference::type::size_type& bytesRead, + bool& willHaveMoreData) +{ +#if defined(AZ_PLATFORM_WINDOWS) + static_assert( + sizeof(std::remove_reference::type::value_type) == sizeof(CHAR), + "buffer elements and CHARs should be of the same size"); + + // Since we're using OVERLAPPED, call to ReadFile() is non-blocking - ReadFile() would return + // immediately if there is no data, and won't wait for any data to arrive. + DWORD bytesReadDword = 0; + auto const hadData + = (ReadFile( + m_readHandle.get(), + buffer.data(), + static_cast(buffer.size()), + &bytesReadDword, + &m_overlapped) + == TRUE); + + bytesRead = static_cast::type>(bytesReadDword); + + // Invoking code should be calling this function until we set willHaveMoreData to true. + // We set it to true when we receive ERROR_BROKEN_PIPE after ReadFile(), which means the process + // has finished and closed the pipe on its end, and it means there won't be more data after + // what've just read. + willHaveMoreData = (GetLastError() != ERROR_BROKEN_PIPE); + + return hadData && bytesRead > 0; +#else + static_assert( + sizeof(std::remove_reference::type::value_type) == sizeof(char), + "buffer elements and chars should be of the same size"); + + auto const nread = read(m_fd[0], buffer.data(), static_cast(buffer.size())); + + bytesRead = static_cast::type>(nread < 0 ? 0 : nread); + willHaveMoreData = (nread > 0 || (nread == -1 && errno == EAGAIN)); + return nread > 0; +#endif +} +} // namespace diff --git a/sdk/identity/azure-identity/src/client_secret_credential.cpp b/sdk/identity/azure-identity/src/client_secret_credential.cpp index 78cdac88a..0cb63c787 100644 --- a/sdk/identity/azure-identity/src/client_secret_credential.cpp +++ b/sdk/identity/azure-identity/src/client_secret_credential.cpp @@ -105,10 +105,7 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken( auto request = std::make_unique( HttpMethod::Post, m_requestUrl, body.str()); - if (m_isAdfs) - { - request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); - } + request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); return request; }); diff --git a/sdk/identity/azure-identity/src/private/token_cache_internals.hpp b/sdk/identity/azure-identity/src/private/token_cache_internals.hpp index d3d67b30f..b1e612ce9 100644 --- a/sdk/identity/azure-identity/src/private/token_cache_internals.hpp +++ b/sdk/identity/azure-identity/src/private/token_cache_internals.hpp @@ -10,6 +10,10 @@ #include "token_cache.hpp" +#if defined(TESTING_BUILD) +#include "azure/identity/dll_import_export.hpp" +#endif + #include #include @@ -76,13 +80,13 @@ namespace Azure { namespace Identity { namespace _detail { * A test hook that gets invoked before cache write lock gets acquired. * */ - static std::function OnBeforeCacheWriteLock; + AZ_IDENTITY_DLLEXPORT static std::function OnBeforeCacheWriteLock; /** * A test hook that gets invoked before item write lock gets acquired. * */ - static std::function OnBeforeItemWriteLock; + AZ_IDENTITY_DLLEXPORT static std::function OnBeforeItemWriteLock; #endif }; }}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/private/token_credential_impl.hpp b/sdk/identity/azure-identity/src/private/token_credential_impl.hpp index eb3e4c6a4..0de06d9e2 100644 --- a/sdk/identity/azure-identity/src/private/token_credential_impl.hpp +++ b/sdk/identity/azure-identity/src/private/token_credential_impl.hpp @@ -48,17 +48,44 @@ namespace Azure { namespace Identity { namespace _detail { * * @param scopes Authentication scopes. * @param asResource `true` if \p scopes need to be formatted as a resource. + * @param urlEncode `true` if the result needs to be URL-encoded. * * @return A string representing scopes so that it can be used in Identity request. * * @note Does not check for \p scopes being empty. */ - static std::string FormatScopes(std::vector const& scopes, bool asResource); + static std::string FormatScopes( + std::vector const& scopes, + bool asResource, + bool urlEncode = true); + + /** + * @brief Parses JSON that contains access token and its expiration. + * + * @param jsonString String with a JSON object to parse. + * @param accessTokenPropertyName Name of a property in the JSON object that represents access + * token. + * @param expiresInPropertyName Name of a property in the JSON object that represents token + * expiration in number of seconds from now. + * @param expiresOnPropertyName Name of a property in the JSON object that represents token + * expiration as absolute date-time stamp. Can be empty, in which case no attempt to parse it is + * made. + * + * @return A successfully parsed access token. + * + * @throw `std::exception` if there was a problem parsing the token. + */ + static Core::Credentials::AccessToken ParseToken( + std::string const& jsonString, + std::string const& accessTokenPropertyName, + std::string const& expiresInPropertyName, + std::string const& expiresOnPropertyName); /** * @brief Holds `#Azure::Core::Http::Request` and all the associated resources for the HTTP - * request body, so that the lifetime for all the resources needed for the request aligns with - * its lifetime, and so that instances of this class can easily be returned from a function. + * request body, so that the lifetime for all the resources needed for the request aligns + * with its lifetime, and so that instances of this class can easily be returned from a + * function. * */ class TokenRequest final { diff --git a/sdk/identity/azure-identity/src/token_credential_impl.cpp b/sdk/identity/azure-identity/src/token_credential_impl.cpp index a58452df2..3c5a8ffc8 100644 --- a/sdk/identity/azure-identity/src/token_credential_impl.cpp +++ b/sdk/identity/azure-identity/src/token_credential_impl.cpp @@ -8,21 +8,37 @@ #include "private/package_version.hpp" #include -#include +#include -using namespace Azure::Identity::_detail; +using Azure::Identity::_detail::TokenCredentialImpl; -TokenCredentialImpl::TokenCredentialImpl(Core::Credentials::TokenCredentialOptions const& options) +using Azure::Identity::_detail::PackageVersion; + +using Azure::Core::Context; +using Azure::Core::Url; +using Azure::Core::Credentials::AccessToken; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredentialOptions; +using Azure::Core::Http::HttpStatusCode; +using Azure::Core::Http::RawResponse; + +TokenCredentialImpl::TokenCredentialImpl(TokenCredentialOptions const& options) : m_httpPipeline(options, "identity", PackageVersion::ToString(), {}, {}) { } +namespace { +std::string OptionalUrlEncode(std::string const& value, bool doEncode) +{ + return doEncode ? Url::Encode(value) : value; +} +} // namespace + std::string TokenCredentialImpl::FormatScopes( std::vector const& scopes, - bool asResource) + bool asResource, + bool urlEncode) { - using Azure::Core::Url; - if (asResource && scopes.size() == 1) { auto resource = scopes[0]; @@ -37,34 +53,39 @@ std::string TokenCredentialImpl::FormatScopes( resource = resource.substr(0, resourceLen - suffixLen); } - return Url::Encode(resource); + return OptionalUrlEncode(resource, urlEncode); } - auto scopesIter = scopes.begin(); - auto scopesStr = Azure::Core::Url::Encode(*scopesIter); - - auto const scopesEnd = scopes.end(); - for (++scopesIter; scopesIter != scopesEnd; ++scopesIter) + std::string scopesStr; { - scopesStr += std::string(" ") + Url::Encode(*scopesIter); + auto scopesIter = scopes.begin(); + auto const scopesEnd = scopes.end(); + + if (scopesIter != scopesEnd) // LCOV_EXCL_LINE + { + auto const scope = *scopesIter; + scopesStr += OptionalUrlEncode(scope, urlEncode); + } + + for (++scopesIter; scopesIter != scopesEnd; ++scopesIter) + { + auto const Separator = std::string(" "); // Element separator never gets URL-encoded + + auto const scope = *scopesIter; + scopesStr += Separator + OptionalUrlEncode(scope, urlEncode); + } } return scopesStr; } -Azure::Core::Credentials::AccessToken TokenCredentialImpl::GetToken( - Core::Context const& context, +AccessToken TokenCredentialImpl::GetToken( + Context const& context, std::function()> const& createRequest, std::function( - Azure::Core::Http::HttpStatusCode statusCode, - Azure::Core::Http::RawResponse const& response)> const& shouldRetry) const + HttpStatusCode statusCode, + RawResponse const& response)> const& shouldRetry) const { - using Azure::Core::Credentials::AuthenticationException; - using Azure::Core::Http::HttpStatusCode; - using Azure::Core::Http::RawResponse; - - static std::string const errorMsgPrefix("GetToken: "); - try { std::unique_ptr response; @@ -75,7 +96,7 @@ Azure::Core::Credentials::AccessToken TokenCredentialImpl::GetToken( response = m_httpPipeline.Send(request->HttpRequest, context); if (!response) { - throw AuthenticationException(errorMsgPrefix + "null response"); + throw std::runtime_error("null response"); } auto const statusCode = response->GetStatusCode(); @@ -87,12 +108,11 @@ Azure::Core::Credentials::AccessToken TokenCredentialImpl::GetToken( request = shouldRetry(statusCode, *response); if (request == nullptr) { - std::ostringstream errorMsg; - errorMsg << errorMsgPrefix << "error response: " - << static_cast::type>(statusCode) << " " - << response->GetReasonPhrase(); - - throw AuthenticationException(errorMsg.str()); + throw std::runtime_error( + std::string("error response: ") + + std::to_string( + static_cast::type>(statusCode)) + + " " + response->GetReasonPhrase()); } response.reset(); @@ -100,79 +120,12 @@ Azure::Core::Credentials::AccessToken TokenCredentialImpl::GetToken( } auto const& responseBodyVector = response->GetBody(); - std::string responseBody(responseBodyVector.begin(), responseBodyVector.end()); - // TODO: use JSON parser. - auto const responseBodySize = responseBody.size(); - - static std::string const jsonExpiresIn = "expires_in"; - static std::string const jsonAccessToken = "access_token"; - - auto responseBodyPos = responseBody.find(':', responseBody.find(jsonExpiresIn)); - if (responseBodyPos == std::string::npos) - { - std::ostringstream errorMsg; - errorMsg << errorMsgPrefix << "response json: \'" << jsonExpiresIn << "\' not found."; - - throw AuthenticationException(errorMsg.str()); - } - - for (; responseBodyPos < responseBodySize; ++responseBodyPos) - { - auto c = responseBody[responseBodyPos]; - if (c != ':' && c != ' ' && c != '\"' && c != '\'') - { - break; - } - } - - long long expiresInSeconds = 0; - for (; responseBodyPos < responseBodySize; ++responseBodyPos) - { - auto c = responseBody[responseBodyPos]; - if (c < '0' || c > '9') - { - break; - } - - expiresInSeconds = (expiresInSeconds * 10) + (static_cast(c) - '0'); - } - - responseBodyPos = responseBody.find(':', responseBody.find(jsonAccessToken)); - if (responseBodyPos == std::string::npos) - { - std::ostringstream errorMsg; - errorMsg << errorMsgPrefix << "response json: \'" << jsonAccessToken << "\' not found."; - - throw AuthenticationException(errorMsg.str()); - } - - for (; responseBodyPos < responseBodySize; ++responseBodyPos) - { - auto c = responseBody[responseBodyPos]; - if (c != ':' && c != ' ' && c != '\"' && c != '\'') - { - break; - } - } - - auto const tokenBegin = responseBodyPos; - for (; responseBodyPos < responseBodySize; ++responseBodyPos) - { - auto c = responseBody[responseBodyPos]; - if (c == '\"' || c == '\'') - { - break; - } - } - auto const tokenEnd = responseBodyPos; - - auto const responseBodyBegin = responseBody.begin(); - - return { - std::string(responseBodyBegin + tokenBegin, responseBodyBegin + tokenEnd), - std::chrono::system_clock::now() + std::chrono::seconds(expiresInSeconds), - }; + return ParseToken( + std::string(responseBodyVector.begin(), responseBodyVector.end()), + "access_token", + "expires_in", + std::string()); } catch (AuthenticationException const&) { @@ -180,10 +133,168 @@ Azure::Core::Credentials::AccessToken TokenCredentialImpl::GetToken( } catch (std::exception const& e) { - throw AuthenticationException(e.what()); - } - catch (...) - { - throw AuthenticationException("unknown error"); + throw AuthenticationException(std::string("GetToken(): ") + e.what()); } } + +namespace { +[[noreturn]] void ThrowMissingJsonPropertyError(std::string const& propertyName) +{ + throw std::runtime_error( + std::string("Token JSON object: \'") + propertyName + "\' property was not found."); +} + +bool GetPropertyValueAsInt64( + std::string const& jsonString, + std::string const& propertyName, + int64_t& outValue); + +bool GetPropertyValueAsString( + std::string const& jsonString, + std::string const& propertyName, + std::string& outValue); +} // namespace + +AccessToken TokenCredentialImpl::ParseToken( + std::string const& jsonString, + std::string const& accessTokenPropertyName, + std::string const& expiresInPropertyName, + std::string const& expiresOnPropertyName) +{ + // TODO: use JSON parser. + AccessToken accessToken; + if (!GetPropertyValueAsString(jsonString, accessTokenPropertyName, accessToken.Token)) + { + ThrowMissingJsonPropertyError(accessTokenPropertyName); + } + + int64_t expiresIn = 0; + if (GetPropertyValueAsInt64(jsonString, expiresInPropertyName, expiresIn)) + { + accessToken.ExpiresOn = std::chrono::system_clock::now() + std::chrono::seconds(expiresIn); + return accessToken; + } + + if (expiresOnPropertyName.empty()) + { + ThrowMissingJsonPropertyError(expiresInPropertyName); + } + + std::string expiresOn; + if (!GetPropertyValueAsString(jsonString, expiresOnPropertyName, expiresOn)) + { + ThrowMissingJsonPropertyError(expiresInPropertyName + "\' or \'" + expiresOnPropertyName); + } + + { + auto const spacePos = expiresOn.find(' '); + if (spacePos != std::string::npos) // LCOV_EXCL_LINE + { + expiresOn = expiresOn.replace(spacePos, 1, 1, 'T'); + } + } + + accessToken.ExpiresOn = Azure::DateTime::Parse(expiresOn, Azure::DateTime::DateFormat::Rfc3339); + return accessToken; +} + +namespace { +std::string::size_type GetPropertyValueStart( + std::string const& jsonString, + std::string const& propertyName); + +bool GetPropertyValueAsInt64( + std::string const& jsonString, + std::string const& propertyName, + int64_t& outValue) +{ + auto const valueStartPos = GetPropertyValueStart(jsonString, propertyName); + if (valueStartPos == std::string::npos) + { + return false; + } + + int64_t value = 0; + { + auto const size = jsonString.size(); + for (auto pos = valueStartPos; pos < size; ++pos) + { + auto c = jsonString[pos]; + if (c < '0' || c > '9') + { + break; + } + + value = (value * 10) + (static_cast(c) - '0'); + } + } + + outValue = value; + + return true; +} + +std::string::size_type GetPropertyValueEnd(std::string const& str, std::string::size_type startPos); + +bool GetPropertyValueAsString( + std::string const& jsonString, + std::string const& propertyName, + std::string& outValue) +{ + auto const valueStartPos = GetPropertyValueStart(jsonString, propertyName); + if (valueStartPos == std::string::npos) + { + return false; + } + auto const jsonStringBegin = jsonString.begin(); + outValue = std::string( + jsonStringBegin + valueStartPos, + jsonStringBegin + GetPropertyValueEnd(jsonString, valueStartPos)); + + return true; +} + +std::string::size_type GetPropertyValueStart( + std::string const& jsonString, + std::string const& propertyName) +{ + auto const propertyNameValueSeparator = jsonString.find(':', jsonString.find(propertyName)); + if (propertyNameValueSeparator == std::string::npos) + { + return std::string::npos; + } + + auto pos = propertyNameValueSeparator; + { + auto const jsonStringSize = jsonString.size(); + for (; pos < jsonStringSize; ++pos) + { + auto c = jsonString[pos]; + if (c != ':' && c != ' ' && c != '\"' && c != '\'') + { + break; + } + } + } + + return pos; +} + +std::string::size_type GetPropertyValueEnd(std::string const& str, std::string::size_type startPos) +{ + auto pos = startPos; + { + auto const strSize = str.size(); + for (; pos < strSize; ++pos) + { + auto c = str[pos]; + if (c == '\"' || c == '\'') + { + break; + } + } + } + + return pos; +} +} // namespace diff --git a/sdk/identity/azure-identity/test/ut/CMakeLists.txt b/sdk/identity/azure-identity/test/ut/CMakeLists.txt index 2b85e8216..a87255464 100644 --- a/sdk/identity/azure-identity/test/ut/CMakeLists.txt +++ b/sdk/identity/azure-identity/test/ut/CMakeLists.txt @@ -16,6 +16,7 @@ add_compile_definitions(AZURE_TEST_RECORDING_DIR="${CMAKE_CURRENT_LIST_DIR}") add_executable ( azure-identity-test + azure_cli_credential_test.cpp chained_token_credential_test.cpp client_certificate_credential_test.cpp client_secret_credential_test.cpp diff --git a/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp new file mode 100644 index 000000000..982d39f07 --- /dev/null +++ b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/identity/azure_cli_credential.hpp" + +#include + +#include +#include +#include +#include + +#include + +using Azure::Identity::AzureCliCredential; + +using Azure::DateTime; +using Azure::Core::Context; +using Azure::Core::Credentials::AuthenticationException; +using Azure::Core::Credentials::TokenCredentialOptions; +using Azure::Core::Credentials::TokenRequestContext; +using Azure::Identity::AzureCliCredentialOptions; + +namespace { +constexpr auto InfiniteCommand = +#if defined(AZ_PLATFORM_WINDOWS) + "for /l %q in (0) do timeout 10"; +#else + "while true; do sleep 10; done" +#endif +; + +constexpr auto EmptyOutputCommand = +#if defined(AZ_PLATFORM_WINDOWS) + "rem"; +#else + "clear" +#endif +; + +std::string EchoCommand(std::string const text) +{ +#if defined(AZ_PLATFORM_WINDOWS) + return std::string("echo ") + text; +#else + return std::string("echo \'") + text + "\'"; +#endif +} + +class AzureCliTestCredential : public AzureCliCredential { +private: + std::string m_command; + + std::string GetAzCommand(std::string const& resource, std::string const& tenantId) const override + { + static_cast(resource); + static_cast(tenantId); + + return m_command; + } + +public: + explicit AzureCliTestCredential(std::string command) : m_command(std::move(command)) {} + + explicit AzureCliTestCredential(std::string command, AzureCliCredentialOptions const& options) + : AzureCliCredential(options), m_command(std::move(command)) + { + } + + explicit AzureCliTestCredential(std::string command, TokenCredentialOptions const& options) + : AzureCliCredential(options), m_command(std::move(command)) + { + } + + std::string GetOriginalAzCommand(std::string const& resource, std::string const& tenantId) const + { + return AzureCliCredential::GetAzCommand(resource, tenantId); + } + + decltype(m_tenantId) const& GetTenantId() const { return m_tenantId; } + decltype(m_cliProcessTimeout) const& GetCliProcessTimeout() const { return m_cliProcessTimeout; } +}; +} // namespace + +TEST(AzureCliCredential, Success) +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," + "\"tenant\":\"72f988bf-86f1-41af-91ab-2d7cd011db47\"," + "\"tokenType\":\"Bearer\"}"; + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, Error) +{ + AzureCliTestCredential const azCliCred( + EchoCommand("ERROR: Please run 'az login' to setup account.")); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); +} + +TEST(AzureCliCredential, EmptyOutput) +{ + AzureCliTestCredential const azCliCred(EmptyOutputCommand); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); +} + +TEST(AzureCliCredential, BigToken) +{ + std::string accessToken; + { + std::string const tokenPart = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + auto const nIterations = ((4 * 1024) / tokenPart.size()) + 1; + for (auto i = 0; i < static_cast(nIterations); ++i) + { + accessToken += tokenPart; + } + } + + AzureCliTestCredential const azCliCred(EchoCommand( + std::string("{\"accessToken\":\"") + accessToken + + "\",\"expiresOn\":\"2022-08-24 00:43:08.000000\"}")); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, accessToken); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, ExpiresIn) +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\",\"expiresIn\":30}"; + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const timestampBefore = std::chrono::system_clock::now(); + auto const token = azCliCred.GetToken(trc, {}); + auto const timestampAfter = std::chrono::system_clock::now(); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_GE(token.ExpiresOn, timestampBefore + std::chrono::seconds(30)); + EXPECT_LE(token.ExpiresOn, timestampAfter + std::chrono::seconds(30)); +} + +TEST(AzureCliCredential, TimedOut) +{ + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::seconds(2); + AzureCliTestCredential const azCliCred(InfiniteCommand, options); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); +} + +TEST(AzureCliCredential, ContextCancelled) +{ + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliTestCredential const azCliCred(InfiniteCommand, options); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto context = Context::ApplicationContext.WithDeadline( + std::chrono::system_clock::now() + std::chrono::hours(24)); + + std::atomic thread1Started(false); + + std::thread thread1([&]() { + thread1Started = true; + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, context)), AuthenticationException); + }); + + std::thread thread2([&]() { + while (!thread1Started) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + std::this_thread::sleep_for(std::chrono::seconds(2)); + context.Cancel(); + }); + + thread1.join(); + thread2.join(); +} + +TEST(AzureCliCredential, Defaults) +{ + { + AzureCliCredentialOptions const DefaultOptions; + + { + AzureCliTestCredential azCliCred({}); + EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + } + + { + TokenCredentialOptions const options; + AzureCliTestCredential azCliCred({}, options); + EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + } + } + + { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.CliProcessTimeout = std::chrono::seconds(12345); + + AzureCliTestCredential azCliCred({}, options); + + EXPECT_EQ(azCliCred.GetTenantId(), "01234567-89AB-CDEF-0123-456789ABCDEF"); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), std::chrono::seconds(12345)); + } +} + +TEST(AzureCliCredential, CmdLine) +{ + AzureCliTestCredential azCliCred({}); + + auto const cmdLineWithoutTenant = azCliCred.GetOriginalAzCommand("https://storage.azure.com", {}); + + auto const cmdLineWithTenant = azCliCred.GetOriginalAzCommand( + "https://storage.azure.com", "01234567-89AB-CDEF-0123-456789ABCDEF"); + + EXPECT_EQ( + cmdLineWithoutTenant, + "az account get-access-token --output json --resource \"https://storage.azure.com\""); + + EXPECT_EQ( + cmdLineWithTenant, + "az account get-access-token --output json --resource \"https://storage.azure.com\"" + " --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\""); +} + +TEST(AzureCliCredential, UnsafeChars) +{ + std::string const Exploit = std::string("\" | echo OWNED | ") + InfiniteCommand + " | echo \""; + + { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.TenantId += Exploit; + + EXPECT_THROW( + static_cast(std::make_unique(options)), AuthenticationException); + } + + { + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Exploit); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + } +} diff --git a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp index d2693f7ab..fb0670842 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp @@ -180,8 +180,7 @@ TEST(TokenCredentialImpl, ThrowInt) {"{\"expires_in\":3600, \"access_token\":\"ACCESSTOKEN\"}"}, [](auto& credential, auto& tokenRequestContext, auto& context) { AccessToken token; - EXPECT_THROW( - token = credential.GetToken(tokenRequestContext, context), AuthenticationException); + EXPECT_THROW(token = credential.GetToken(tokenRequestContext, context), int); return token; })); }