diff --git a/cmake-modules/AzureDoxygen.cmake b/cmake-modules/AzureDoxygen.cmake index 2a410e038..31b2e2ffd 100644 --- a/cmake-modules/AzureDoxygen.cmake +++ b/cmake-modules/AzureDoxygen.cmake @@ -26,14 +26,6 @@ function(generate_documentation PROJECT_NAME PROJECT_VERSION) # classes and enums directly into the documentation. set(DOXYGEN_INLINE_SOURCES NO) set(DOXYGEN_MARKDOWN_ID_STYLE GITHUB) - # Used to correctly expand macros like _azure_NON_FINAL_FOR_TESTS when generating docs. - # Using EXPAND_ONLY_PREDEF to limit macro expansion to the macros specified with the PREDEFINED tags. - set(DOXYGEN_MACRO_EXPANSION YES) - set(EXPAND_ONLY_PREDEF YES) - set(DOXYGEN_PREDEFINED - _azure_NON_FINAL_FOR_TESTS=final - _azure_VIRTUAL_FOR_TESTS= - ) # Skip generating docs for json, test, samples, and private files. set(DOXYGEN_EXCLUDE_PATTERNS json.hpp diff --git a/sdk/core/azure-core-amqp/CMakeLists.txt b/sdk/core/azure-core-amqp/CMakeLists.txt index af4d71c96..96c4ccce6 100644 --- a/sdk/core/azure-core-amqp/CMakeLists.txt +++ b/sdk/core/azure-core-amqp/CMakeLists.txt @@ -192,7 +192,7 @@ az_rtti_setup( if(BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/claims_based_security.hpp b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/claims_based_security.hpp index 8c4c6c0f9..dcb82ee6b 100644 --- a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/claims_based_security.hpp +++ b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/claims_based_security.hpp @@ -37,7 +37,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { Jwt, }; -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) /** @brief Implementation of AMQP 1.0 Claims-based Security (CBS) protocol. * * This class allows AMQP clients to implement the CBS protocol for authentication and @@ -74,5 +74,5 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { private: std::shared_ptr m_impl; }; -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD }}}} // namespace Azure::Core::Amqp::_detail diff --git a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/connection.hpp b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/connection.hpp index 1c72489ce..5798add6b 100644 --- a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/connection.hpp +++ b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/connection.hpp @@ -21,7 +21,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { class ConnectionFactory; }}}} // namespace Azure::Core::Amqp::_detail -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // Define the test classes dependant on this class here. namespace Azure { namespace Core { namespace Amqp { namespace Tests { namespace MessageTests { @@ -44,7 +44,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace Tests { class TestMessages_ReceiverReceiveAsync_Test; }}}} // namespace Azure::Core::Amqp::Tests -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD #if defined(SAMPLES_BUILD) namespace LocalServerSample { int LocalServerSampleMain(); @@ -452,7 +452,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _internal { std::shared_ptr<_detail::ConnectionImpl> m_impl; friend class _detail::ConnectionFactory; -#if _azure_TESTING_BUILD +#if TESTING_BUILD friend class Azure::Core::Amqp::Tests::MessageTests::AmqpServerMock; friend class Azure::Core::Amqp::Tests::MessageTests::MessageListenerEvents; friend class Azure::Core::Amqp::Tests::TestSocketListenerEvents; @@ -467,7 +467,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _internal { friend class Azure::Core::Amqp::Tests::TestMessages_SenderSendAsync_Test; friend class Azure::Core::Amqp::Tests::TestMessages_SenderOpenClose_Test; -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD #if SAMPLES_BUILD friend int LocalServerSample::LocalServerSampleMain(); #endif // SAMPLES_BUILD diff --git a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/link.hpp b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/link.hpp index 137984068..6d93bb18e 100644 --- a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/link.hpp +++ b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/link.hpp @@ -70,7 +70,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { class LinkEvents { public: virtual Models::AmqpValue OnTransferReceived( -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) Link const& link, #else std::shared_ptr link, @@ -80,7 +80,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { const unsigned char* payloadBytes) = 0; virtual void OnLinkStateChanged( -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) Link const& link, #else std::shared_ptr link, @@ -89,7 +89,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { LinkState previousLinkState) = 0; virtual void OnLinkFlowOn( -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) Link const& link #else std::shared_ptr link @@ -99,7 +99,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { virtual ~LinkEvents() = default; }; -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) class Link final { public: @@ -172,5 +172,5 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { std::shared_ptr m_impl; }; -#endif // defined(_azure_TESTING_BUILD) +#endif // defined(TESTING_BUILD) }}}} // namespace Azure::Core::Amqp::_detail diff --git a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/session.hpp b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/session.hpp index 93c7e4005..572bfeabf 100644 --- a/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/session.hpp +++ b/sdk/core/azure-core-amqp/inc/azure/core/amqp/internal/session.hpp @@ -16,7 +16,7 @@ #include #include -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // Define the test classes dependant on this class here. namespace Azure { namespace Core { namespace Amqp { namespace Tests { namespace MessageTests { @@ -34,7 +34,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace Tests { class LinkSocketListenerEvents; class TestMessages_SenderSendAsync_Test; }}}} // namespace Azure::Core::Amqp::Tests -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD #if defined(SAMPLES_BUILD) namespace LocalServerSample { class SampleEvents; @@ -245,7 +245,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _internal { friend class _detail::SessionFactory; -#if _azure_TESTING_BUILD +#if TESTING_BUILD friend class Azure::Core::Amqp::Tests::MessageTests::AmqpServerMock; friend class Azure::Core::Amqp::Tests::MessageTests::MockServiceEndpoint; friend class Azure::Core::Amqp::Tests::MessageTests::MessageListenerEvents; @@ -258,7 +258,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _internal { friend class Azure::Core::Amqp::Tests::TestLinks_LinkAttachDetach_Test; friend class Azure::Core::Amqp::Tests::TestMessages_SenderSendAsync_Test; -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD #if SAMPLES_BUILD friend class LocalServerSample::SampleEvents; #endif // SAMPLES_BUILD diff --git a/sdk/core/azure-core-amqp/src/amqp/claim_based_security.cpp b/sdk/core/azure-core-amqp/src/amqp/claim_based_security.cpp index 7cfe8e102..35e78cfb3 100644 --- a/sdk/core/azure-core-amqp/src/amqp/claim_based_security.cpp +++ b/sdk/core/azure-core-amqp/src/amqp/claim_based_security.cpp @@ -16,7 +16,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { using namespace Azure::Core::Amqp::_internal; // The non-Impl types for CBS exist only for testing purposes. -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) ClaimsBasedSecurity::ClaimsBasedSecurity(Session const& session) : m_impl{std::make_shared<_detail::ClaimsBasedSecurityImpl>(SessionFactory::GetImpl(session))} { @@ -37,7 +37,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { return m_impl->PutToken(tokenType, audience, token, context); } -#endif // _azure_TESTING_BUILD +#endif // TESTING_BUILD ClaimsBasedSecurityImpl::ClaimsBasedSecurityImpl(std::shared_ptr<_detail::SessionImpl> session) : m_session{session} diff --git a/sdk/core/azure-core-amqp/src/amqp/link.cpp b/sdk/core/azure-core-amqp/src/amqp/link.cpp index 5cb0ec473..44a05d594 100644 --- a/sdk/core/azure-core-amqp/src/amqp/link.cpp +++ b/sdk/core/azure-core-amqp/src/amqp/link.cpp @@ -15,7 +15,7 @@ #include namespace Azure { namespace Core { namespace Amqp { namespace _detail { -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) Link::Link( _internal::Session const& session, std::string const& name, @@ -475,7 +475,7 @@ namespace Azure { namespace Core { namespace Amqp { namespace _detail { { return Models::_detail::AmqpValueFactory::ToUamqp(link->m_eventHandler->OnTransferReceived( -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) Link{link->shared_from_this()}, #else link->shared_from_this(), diff --git a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt index 7c858a11a..85aa4f5f3 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt +++ b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt @@ -96,7 +96,7 @@ endif() if(BUILD_AZURE_CORE_TRACING_OPENTELEMETRY AND BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index 68b986f38..24fff3915 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -96,7 +96,6 @@ set( inc/azure/core/internal/json/json_optional.hpp inc/azure/core/internal/json/json_serializable.hpp inc/azure/core/internal/strings.hpp - inc/azure/core/internal/test_hooks.hpp inc/azure/core/internal/tracing/service_tracing.hpp inc/azure/core/internal/tracing/tracing_impl.hpp inc/azure/core/internal/unique_handle.hpp @@ -205,7 +204,7 @@ az_rtti_setup( if(BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core/inc/azure/core/http/http.hpp b/sdk/core/azure-core/inc/azure/core/http/http.hpp index 373d55bb2..8d191b62f 100644 --- a/sdk/core/azure-core/inc/azure/core/http/http.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/http.hpp @@ -28,7 +28,7 @@ #include #include -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // Define the class used from tests to validate retry enabled namespace Azure { namespace Core { namespace Test { class TestHttp_getters_Test; @@ -181,7 +181,7 @@ namespace Azure { namespace Core { namespace Http { */ class Request final { friend class Azure::Core::Http::Policies::_internal::RetryPolicy; -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // make tests classes friends to validate set Retry friend class Azure::Core::Test::TestHttp_getters_Test; friend class Azure::Core::Test::TestHttp_query_parameter_Test; diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index 447319b1a..f9e4ec1ab 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -16,7 +16,6 @@ #include "azure/core/http/transport.hpp" #include "azure/core/internal/http/http_sanitizer.hpp" #include "azure/core/internal/http/user_agent.hpp" -#include "azure/core/internal/test_hooks.hpp" #include "azure/core/uuid.hpp" #include @@ -31,14 +30,6 @@ #include #include -#if defined(_azure_TESTING_BUILD) -// Define the class used from tests to validate retry policy -namespace Azure { namespace Core { namespace Test { - class RetryPolicyTest; - class RetryLogic; -}}} // namespace Azure::Core::Test -#endif - /** * A function that should be implemented and linked to the end-user application in order to override * an HTTP transport implementation provided by Azure SDK with custom implementation. @@ -372,13 +363,11 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { /** * @brief HTTP retry policy. */ - class RetryPolicy _azure_NON_FINAL_FOR_TESTS : public HttpPolicy { - -#if defined(_azure_TESTING_BUILD) - friend class Azure::Core::Test::RetryPolicyTest; - friend class Azure::Core::Test::RetryLogic; + class RetryPolicy +#if !defined(TESTING_BUILD) + final #endif - + : public HttpPolicy { private: RetryOptions m_retryOptions; @@ -413,14 +402,14 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { */ static int32_t GetRetryCount(Context const& context); - private: - _azure_VIRTUAL_FOR_TESTS bool ShouldRetryOnTransportFailure( + protected: + virtual bool ShouldRetryOnTransportFailure( RetryOptions const& retryOptions, int32_t attempt, std::chrono::milliseconds& retryAfter, double jitterFactor = -1) const; - _azure_VIRTUAL_FOR_TESTS bool ShouldRetryOnResponse( + virtual bool ShouldRetryOnResponse( RawResponse const& response, RetryOptions const& retryOptions, int32_t attempt, diff --git a/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp b/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp deleted file mode 100644 index c2f555f51..000000000 --- a/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -/** - * @file - * @brief This file is used to define internal-only macros that are used to control the behavior of - * the Azure SDK when running tests, to allow mocking within tests. - * The macros in this file should NOT be used by anyone outside this repo. - */ - -#pragma once - -// When testing is enabled, we want to make sure that certain classes are not final, so that we can -// mock it. -#if defined(_azure_TESTING_BUILD) - -/** - * @brief If we are testing, we want to make sure that classes are not final, by default. - */ -#if !defined(_azure_NON_FINAL_FOR_TESTS) -#define _azure_NON_FINAL_FOR_TESTS -#endif - -/** - * @brief If we are testing, we want to make sure methods can be made virtual, for mocking. - */ -#if !defined(_azure_VIRTUAL_FOR_TESTS) -#define _azure_VIRTUAL_FOR_TESTS virtual -#endif - -#else - -/** - * @brief If we are not testing, we want to make sure that classes are final, by default. - */ -#if !defined(_azure_NON_FINAL_FOR_TESTS) -#define _azure_NON_FINAL_FOR_TESTS final -#endif - -/** - * @brief If we are not testing, we don't need to make methods virtual for mocking. - */ -#if !defined(_azure_VIRTUAL_FOR_TESTS) -#define _azure_VIRTUAL_FOR_TESTS -#endif - -#endif diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp index 4560ed0c2..b99ea48c1 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp @@ -23,7 +23,7 @@ #include #include -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // Define the class name that reads from ConnectionPool private members namespace Azure { namespace Core { namespace Test { class CurlConnectionPool_connectionPoolTest_Test; @@ -43,7 +43,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { * connection pool per application. */ class CurlConnectionPool final { -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // Give access to private to this tests class friend class Azure::Core::Test::CurlConnectionPool_connectionPoolTest_Test; friend class Azure::Core::Test::CurlConnectionPool_uniquePort_Test; diff --git a/sdk/core/azure-core/src/http/curl/curl_session_private.hpp b/sdk/core/azure-core/src/http/curl/curl_session_private.hpp index fb77013f7..d10beaf3c 100644 --- a/sdk/core/azure-core/src/http/curl/curl_session_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_session_private.hpp @@ -18,7 +18,7 @@ #include #include -#ifdef _azure_TESTING_BUILD +#ifdef TESTING_BUILD // Define the class name that reads from ConnectionPool private members namespace Azure { namespace Core { namespace Test { class CurlConnectionPool_connectionPoolTest_Test; @@ -40,7 +40,7 @@ namespace Azure { namespace Core { namespace Http { * transporter to be reusable in multiple pipelines while every call to network is unique. */ class CurlSession final : public Azure::Core::IO::BodyStream { -#ifdef _azure_TESTING_BUILD +#ifdef TESTING_BUILD // Give access to private to this tests class friend class Azure::Core::Test::CurlConnectionPool_connectionPoolTest_Test; friend class Azure::Core::Test::SdkWithLibcurl_DISABLED_globalCleanUp_Test; diff --git a/sdk/core/azure-core/test/ut/retry_policy_test.cpp b/sdk/core/azure-core/test/ut/retry_policy_test.cpp index c18870349..b482418ce 100644 --- a/sdk/core/azure-core/test/ut/retry_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/retry_policy_test.cpp @@ -13,825 +13,823 @@ using namespace Azure::Core::Http; using namespace Azure::Core::Http::Policies; using namespace Azure::Core::Http::Policies::_internal; -namespace Azure { namespace Core { namespace Test { - class TestTransportPolicy final : public HttpPolicy { - private: - std::function()> m_send; +namespace { +class TestTransportPolicy final : public HttpPolicy { +private: + std::function()> m_send; - public: - TestTransportPolicy(std::function()> send) : m_send(send) {} +public: + TestTransportPolicy(std::function()> send) : m_send(send) {} - std::unique_ptr Send( - Request&, - NextHttpPolicy, - Azure::Core::Context const&) const override - { - return m_send(); - } - - std::unique_ptr Clone() const override - { - return std::make_unique(*this); - } - }; - - class RetryPolicyTest final : public RetryPolicy { - private: - std::function - m_shouldRetryOnTransportFailure; - - std::function< - bool(RawResponse const&, RetryOptions const&, int32_t, std::chrono::milliseconds&, double)> - m_shouldRetryOnResponse; - - public: - bool BaseShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const - { - return RetryPolicy::ShouldRetryOnTransportFailure( - retryOptions, attempt, retryAfter, jitterFactor); - } - - bool BaseShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const - { - return RetryPolicy::ShouldRetryOnResponse( - response, retryOptions, attempt, retryAfter, jitterFactor); - } - - RetryPolicyTest( - RetryOptions const& retryOptions, - decltype(m_shouldRetryOnTransportFailure) shouldRetryOnTransportFailure, - decltype(m_shouldRetryOnResponse) shouldRetryOnResponse) - : RetryPolicy(retryOptions), - m_shouldRetryOnTransportFailure( - shouldRetryOnTransportFailure != nullptr // - ? shouldRetryOnTransportFailure - : static_cast( // - [this](auto options, auto attempt, auto retryAfter, auto jitter) { - retryAfter = std::chrono::milliseconds(0); - auto ignore = decltype(retryAfter)(); - return this->BaseShouldRetryOnTransportFailure( - options, attempt, ignore, jitter); - })), - m_shouldRetryOnResponse( - shouldRetryOnResponse != nullptr // - ? shouldRetryOnResponse - : static_cast( // - [this]( - RawResponse const& response, - auto options, - auto attempt, - auto retryAfter, - auto jitter) { - retryAfter = std::chrono::milliseconds(0); - auto ignore = decltype(retryAfter)(); - return this->BaseShouldRetryOnResponse( - response, options, attempt, ignore, jitter); - })) - { - } - - std::unique_ptr Clone() const override - { - return std::make_unique(*this); - } - - protected: - bool ShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const override - { - return m_shouldRetryOnTransportFailure(retryOptions, attempt, retryAfter, jitterFactor); - } - - bool ShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const override - { - return m_shouldRetryOnResponse(response, retryOptions, attempt, retryAfter, jitterFactor); - } - }; - - TEST(RetryPolicy, ShouldRetryOnResponse) + std::unique_ptr Send( + Request&, + NextHttpPolicy, + Azure::Core::Context const&) const override { - using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - - RawResponse const* responsePtrSent = nullptr; - - RawResponse const* responsePtrReceived = nullptr; - RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; - int32_t attemptReceived = -1234; - double jitterReceived = -5678; - - int onTransportFailureInvoked = 0; - int onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](RawResponse const& response, auto options, auto attempt, auto, auto jitter) { - ++onResponseInvoked; - responsePtrReceived = &response; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - })); - - policies.emplace_back(std::make_unique([&]() { - auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - - responsePtrSent = response.get(); - - return response; - })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); - } - - EXPECT_EQ(onTransportFailureInvoked, 0); - EXPECT_EQ(onResponseInvoked, 1); - - EXPECT_NE(responsePtrSent, nullptr); - EXPECT_EQ(responsePtrSent, responsePtrReceived); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 1); - EXPECT_EQ(jitterReceived, -1); - - // 3 attempts - responsePtrSent = nullptr; - - responsePtrReceived = nullptr; - retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; - attemptReceived = -1234; - jitterReceived = -5678; - - onTransportFailureInvoked = 0; - onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](RawResponse const& response, - auto options, - auto attempt, - auto retryAfter, - auto jitter) { - ++onResponseInvoked; - responsePtrReceived = &response; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - retryAfter = 1ms; - return onResponseInvoked < 3; - })); - - policies.emplace_back(std::make_unique([&]() { - auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - - responsePtrSent = response.get(); - - return response; - })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); - } - - EXPECT_EQ(onTransportFailureInvoked, 0); - EXPECT_EQ(onResponseInvoked, 3); - - EXPECT_NE(responsePtrSent, nullptr); - EXPECT_EQ(responsePtrSent, responsePtrReceived); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 3); - EXPECT_EQ(jitterReceived, -1); + return m_send(); } - TEST(RetryPolicy, ShouldRetryOnTransportFailure) + std::unique_ptr Clone() const override { - using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; + return std::make_unique(*this); + } +}; - RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; - int32_t attemptReceived = -1234; - double jitterReceived = -5678; +class RetryPolicyTest final : public RetryPolicy { +private: + std::function + m_shouldRetryOnTransportFailure; - int onTransportFailureInvoked = 0; - int onResponseInvoked = 0; + std::function< + bool(RawResponse const&, RetryOptions const&, int32_t, std::chrono::milliseconds&, double)> + m_shouldRetryOnResponse; - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](auto, auto options, auto attempt, auto, auto jitter) { - ++onResponseInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - })); - - policies.emplace_back(std::make_unique( - []() -> std::unique_ptr { throw TransportException("Test"); })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); - } - - EXPECT_EQ(onTransportFailureInvoked, 1); - EXPECT_EQ(onResponseInvoked, 0); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 1); - EXPECT_EQ(jitterReceived, -1); - - // 3 attempts - retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; - attemptReceived = -1234; - jitterReceived = -5678; - - onTransportFailureInvoked = 0; - onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return onTransportFailureInvoked < 3; - }, - [&](auto, auto options, auto attempt, auto retryAfter, auto jitter) { - ++onResponseInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - retryAfter = 1ms; - return false; - })); - - policies.emplace_back(std::make_unique( - []() -> std::unique_ptr { throw TransportException("Test"); })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); - } - - EXPECT_EQ(onTransportFailureInvoked, 3); - EXPECT_EQ(onResponseInvoked, 0); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 3); - EXPECT_EQ(jitterReceived, -1); +public: + bool BaseShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const + { + return RetryPolicy::ShouldRetryOnTransportFailure( + retryOptions, attempt, retryAfter, jitterFactor); } - class RetryLogic final : private RetryPolicy { - RetryLogic() : RetryPolicy(RetryOptions()) {} - ~RetryLogic() {} - - static RetryLogic const g_retryPolicy; - - public: - static bool TestShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) - { - return g_retryPolicy.ShouldRetryOnTransportFailure( - retryOptions, attempt, retryAfter, jitterFactor); - } - - static bool TestShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) - { - return g_retryPolicy.ShouldRetryOnResponse( - response, retryOptions, attempt, retryAfter, jitterFactor); - } - }; - - RetryLogic const RetryLogic::g_retryPolicy; - - TEST(RetryPolicy, Exponential) + bool BaseShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const { - using namespace std::chrono_literals; - - RetryOptions const options{3, 1s, 2min, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 4s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } + return RetryPolicy::ShouldRetryOnResponse( + response, retryOptions, attempt, retryAfter, jitterFactor); } - TEST(RetryPolicy, LessThan2Retries) + RetryPolicyTest( + RetryOptions const& retryOptions, + decltype(m_shouldRetryOnTransportFailure) shouldRetryOnTransportFailure, + decltype(m_shouldRetryOnResponse) shouldRetryOnResponse) + : RetryPolicy(retryOptions), + m_shouldRetryOnTransportFailure( + shouldRetryOnTransportFailure != nullptr // + ? shouldRetryOnTransportFailure + : static_cast( // + [this](auto options, auto attempt, auto retryAfter, auto jitter) { + retryAfter = std::chrono::milliseconds(0); + auto ignore = decltype(retryAfter)(); + return this->BaseShouldRetryOnTransportFailure( + options, attempt, ignore, jitter); + })), + m_shouldRetryOnResponse( + shouldRetryOnResponse != nullptr // + ? shouldRetryOnResponse + : static_cast( // + [this]( + RawResponse const& response, + auto options, + auto attempt, + auto retryAfter, + auto jitter) { + retryAfter = std::chrono::milliseconds(0); + auto ignore = decltype(retryAfter)(); + return this->BaseShouldRetryOnResponse( + response, options, attempt, ignore, jitter); + })) { - using namespace std::chrono_literals; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({1, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({0, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({-1, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } } - TEST(RetryPolicy, NotExceedingMaxRetryDelay) + std::unique_ptr Clone() const override { - using namespace std::chrono_literals; - - RetryOptions const options{7, 1s, 20s, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 4s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 8s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 5, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 16s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 6, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 20s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 7, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 20s); - } + return std::make_unique(*this); } - TEST(RetryPolicy, NotExceedingInt32Max) +protected: + bool ShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const override { - using namespace std::chrono_literals; - - RetryOptions const options{35, 1s, 9999999999999s, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 31, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1073741824s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 32, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 33, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 34, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } + return m_shouldRetryOnTransportFailure(retryOptions, attempt, retryAfter, jitterFactor); } - TEST(RetryPolicy, Jitter) + bool ShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const override { - using namespace std::chrono_literals; + return m_shouldRetryOnResponse(response, retryOptions, attempt, retryAfter, jitterFactor); + } +}; +} // namespace - RetryOptions const options{3, 10s, 20min, {}}; +TEST(RetryPolicy, ShouldRetryOnResponse) +{ + using namespace std::chrono_literals; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 0.8); + RawResponse const* responsePtrSent = nullptr; - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 8s); - } + RawResponse const* responsePtrReceived = nullptr; + RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; + int32_t attemptReceived = -1234; + double jitterReceived = -5678; - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.3); + int onTransportFailureInvoked = 0; + int onResponseInvoked = 0; - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 13s); - } + { + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 0.8); + return false; + }, + [&](RawResponse const& response, auto options, auto attempt, auto, auto jitter) { + ++onResponseInvoked; + responsePtrReceived = &response; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 16s); - } + return false; + })); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.3); + policies.emplace_back(std::make_unique([&]() { + auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 26s); - } + responsePtrSent = response.get(); + + return response; + })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); } - TEST(RetryPolicy, JitterExtremes) + EXPECT_EQ(onTransportFailureInvoked, 0); + EXPECT_EQ(onResponseInvoked, 1); + + EXPECT_NE(responsePtrSent, nullptr); + EXPECT_EQ(responsePtrSent, responsePtrReceived); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 1); + EXPECT_EQ(jitterReceived, -1); + + // 3 attempts + responsePtrSent = nullptr; + + responsePtrReceived = nullptr; + retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; + attemptReceived = -1234; + jitterReceived = -5678; + + onTransportFailureInvoked = 0; + onResponseInvoked = 0; + { - using namespace std::chrono_literals; + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 1ms, 2min, {}}, 1, retryAfter, 0.8); + return false; + }, + [&](RawResponse const& response, auto options, auto attempt, auto retryAfter, auto jitter) { + ++onResponseInvoked; + responsePtrReceived = &response; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 0ms); - } + retryAfter = 1ms; + return onResponseInvoked < 3; + })); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 2ms, 2min, {}}, 1, retryAfter, 0.8); + policies.emplace_back(std::make_unique([&]() { + auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1ms); - } + responsePtrSent = response.get(); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 2, retryAfter, 1.3); + return response; + })); - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 21s); - } + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 3, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 21s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnTransportFailure( - {35, 1s, 9999999999999s, {}}, 33, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2791728741100ms); - } + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); } - TEST(RetryPolicy, HttpStatusCode) + EXPECT_EQ(onTransportFailureInvoked, 0); + EXPECT_EQ(onResponseInvoked, 3); + + EXPECT_NE(responsePtrSent, nullptr); + EXPECT_EQ(responsePtrSent, responsePtrReceived); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 3); + EXPECT_EQ(jitterReceived, -1); +} + +TEST(RetryPolicy, ShouldRetryOnTransportFailure) +{ + using namespace std::chrono_literals; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; + + RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; + int32_t attemptReceived = -1234; + double jitterReceived = -5678; + + int onTransportFailureInvoked = 0; + int onResponseInvoked = 0; + { - using namespace std::chrono_literals; + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), - {3, 3210s, 3h, {HttpStatusCode::RequestTimeout}}, - 1, - retryAfter, - 1.0); + return false; + }, + [&](auto, auto options, auto attempt, auto, auto jitter) { + ++onResponseInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 3210s); - } + return false; + })); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), - {3, 654s, 3h, {HttpStatusCode::Ok}}, - 1, - retryAfter, - 1.0); + policies.emplace_back(std::make_unique( + []() -> std::unique_ptr { throw TransportException("Test"); })); - EXPECT_EQ(shouldRetry, false); - } + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::Ok, ""), - {3, 987s, 3h, {HttpStatusCode::Ok}}, - 1, - retryAfter, - 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 987s); - } + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); } - TEST(RetryPolicy, RetryAfterMs) + EXPECT_EQ(onTransportFailureInvoked, 1); + EXPECT_EQ(onResponseInvoked, 0); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 1); + EXPECT_EQ(jitterReceived, -1); + + // 3 attempts + retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; + attemptReceived = -1234; + jitterReceived = -5678; + + onTransportFailureInvoked = 0; + onResponseInvoked = 0; + { - using namespace std::chrono_literals; + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("rEtRy-aFtEr-mS", "1234"); + return onTransportFailureInvoked < 3; + }, + [&](auto, auto options, auto attempt, auto retryAfter, auto jitter) { + ++onResponseInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.3); + retryAfter = 1ms; + return false; + })); - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1234ms); - } + policies.emplace_back(std::make_unique( + []() -> std::unique_ptr { throw TransportException("Test"); })); - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("X-mS-ReTrY-aFtEr-MS", "5678"); + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 5678ms); - } + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); } - TEST(RetryPolicy, RetryAfter) + EXPECT_EQ(onTransportFailureInvoked, 3); + EXPECT_EQ(onResponseInvoked, 0); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 3); + EXPECT_EQ(jitterReceived, -1); +} + +namespace { +class RetryLogic final : private RetryPolicy { + RetryLogic() : RetryPolicy(RetryOptions()) {} + ~RetryLogic() {} + + static RetryLogic const g_retryPolicy; + +public: + static bool TestShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) { - using namespace std::chrono_literals; - - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("rEtRy-aFtEr", "90"); - - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.1); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 90s); - } + return g_retryPolicy.ShouldRetryOnTransportFailure( + retryOptions, attempt, retryAfter, jitterFactor); } - TEST(RetryPolicy, LogMessages) + static bool TestShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) { - using Azure::Core::Diagnostics::Logger; + return g_retryPolicy.ShouldRetryOnResponse( + response, retryOptions, attempt, retryAfter, jitterFactor); + } +}; - struct Log +RetryLogic const RetryLogic::g_retryPolicy; +} // namespace + +TEST(RetryPolicy, Exponential) +{ + using namespace std::chrono_literals; + + RetryOptions const options{3, 1s, 2min, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 4s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } +} + +TEST(RetryPolicy, LessThan2Retries) +{ + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({1, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({0, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({-1, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } +} + +TEST(RetryPolicy, NotExceedingMaxRetryDelay) +{ + using namespace std::chrono_literals; + + RetryOptions const options{7, 1s, 20s, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 4s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 8s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 5, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 16s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 6, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 20s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 7, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 20s); + } +} + +TEST(RetryPolicy, NotExceedingInt32Max) +{ + using namespace std::chrono_literals; + + RetryOptions const options{35, 1s, 9999999999999s, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 31, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1073741824s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 32, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 33, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 34, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } +} + +TEST(RetryPolicy, Jitter) +{ + using namespace std::chrono_literals; + + RetryOptions const options{3, 10s, 20min, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 8s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 13s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 16s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 26s); + } +} + +TEST(RetryPolicy, JitterExtremes) +{ + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 1ms, 2min, {}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 0ms); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 2ms, 2min, {}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1ms); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 2, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 21s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 3, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 21s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnTransportFailure( + {35, 1s, 9999999999999s, {}}, 33, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2791728741100ms); + } +} + +TEST(RetryPolicy, HttpStatusCode) +{ + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), + {3, 3210s, 3h, {HttpStatusCode::RequestTimeout}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 3210s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), + {3, 654s, 3h, {HttpStatusCode::Ok}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, false); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::Ok, ""), + {3, 987s, 3h, {HttpStatusCode::Ok}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 987s); + } +} + +TEST(RetryPolicy, RetryAfterMs) +{ + using namespace std::chrono_literals; + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("rEtRy-aFtEr-mS", "1234"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1234ms); + } + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("X-mS-ReTrY-aFtEr-MS", "5678"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 5678ms); + } +} + +TEST(RetryPolicy, RetryAfter) +{ + using namespace std::chrono_literals; + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("rEtRy-aFtEr", "90"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.1); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 90s); + } +} + +TEST(RetryPolicy, LogMessages) +{ + using Azure::Core::Diagnostics::Logger; + + struct Log + { + struct Entry { - struct Entry + Logger::Level Level; + std::string Message; + }; + + std::vector Entries; + + Log() + { + Logger::SetLevel(Logger::Level::Informational); + Logger::SetListener([&](auto lvl, auto msg) { Entries.emplace_back(Entry{lvl, msg}); }); + } + + ~Log() + { + Logger::SetListener(nullptr); + Logger::SetLevel(Logger::Level::Warning); + } + + } log; + + { + using namespace std::chrono_literals; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::InternalServerError}}; + + auto requestNumber = 0; + + std::vector> policies; + policies.emplace_back(std::make_unique(retryOptions, nullptr, nullptr)); + policies.emplace_back(std::make_unique([&]() { + ++requestNumber; + + if (requestNumber == 1) { - Logger::Level Level; - std::string Message; - }; - - std::vector Entries; - - Log() - { - Logger::SetLevel(Logger::Level::Informational); - Logger::SetListener([&](auto lvl, auto msg) { Entries.emplace_back(Entry{lvl, msg}); }); + throw TransportException("Cable Unplugged"); } - ~Log() - { - Logger::SetListener(nullptr); - Logger::SetLevel(Logger::Level::Warning); - } + return std::make_unique( + 1, + 1, + requestNumber == 2 ? HttpStatusCode::InternalServerError + : HttpStatusCode::ServiceUnavailable, + "Test"); + })); - } log; + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - { - using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::InternalServerError}}; - - auto requestNumber = 0; - - std::vector> policies; - policies.emplace_back(std::make_unique(retryOptions, nullptr, nullptr)); - policies.emplace_back(std::make_unique([&]() { - ++requestNumber; - - if (requestNumber == 1) - { - throw TransportException("Cable Unplugged"); - } - - return std::make_unique( - 1, - 1, - requestNumber == 2 ? HttpStatusCode::InternalServerError - : HttpStatusCode::ServiceUnavailable, - "Test"); - })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); - } - - EXPECT_EQ(log.Entries.size(), 5); - - EXPECT_EQ(log.Entries[0].Level, Logger::Level::Warning); - EXPECT_EQ(log.Entries[0].Message, "HTTP Transport error: Cable Unplugged"); - - EXPECT_EQ(log.Entries[1].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[1].Message, "HTTP Retry attempt #1 will be made in 0ms."); - - EXPECT_EQ(log.Entries[2].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[2].Message, "HTTP status code 500 will be retried."); - - EXPECT_EQ(log.Entries[3].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[3].Message, "HTTP Retry attempt #2 will be made in 0ms."); - - EXPECT_EQ(log.Entries[4].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[4].Message, "HTTP status code 503 won't be retried."); + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); } -}}} // namespace Azure::Core::Test + + EXPECT_EQ(log.Entries.size(), 5); + + EXPECT_EQ(log.Entries[0].Level, Logger::Level::Warning); + EXPECT_EQ(log.Entries[0].Message, "HTTP Transport error: Cable Unplugged"); + + EXPECT_EQ(log.Entries[1].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[1].Message, "HTTP Retry attempt #1 will be made in 0ms."); + + EXPECT_EQ(log.Entries[2].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[2].Message, "HTTP status code 500 will be retried."); + + EXPECT_EQ(log.Entries[3].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[3].Message, "HTTP Retry attempt #2 will be made in 0ms."); + + EXPECT_EQ(log.Entries[4].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[4].Message, "HTTP status code 503 won't be retried."); +} diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index 50df76415..a53226b39 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -126,7 +126,7 @@ az_rtti_setup( if(BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) # tests if (NOT AZ_ALL_LIBRARIES OR FETCH_SOURCE_DEPS) diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp index be24fd050..981296692 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -13,18 +13,11 @@ #include #include #include -#include #include #include #include -#if defined(_azure_TESTING_BUILD) -namespace Azure { namespace Identity { namespace Test { - class AzureCliTestCredential; -}}} // namespace Azure::Identity::Test -#endif - namespace Azure { namespace Identity { /** * @brief Options for configuring the #Azure::Identity::AzureCliCredential. @@ -56,12 +49,11 @@ namespace Azure { namespace Identity { * @brief Enables authentication to Microsoft Entra ID using Azure CLI to obtain an access * token. */ - class AzureCliCredential _azure_NON_FINAL_FOR_TESTS : public Core::Credentials::TokenCredential { - -#if defined(_azure_TESTING_BUILD) - friend class Azure::Identity::Test::AzureCliTestCredential; + class AzureCliCredential +#if !defined(TESTING_BUILD) + final #endif - + : public Core::Credentials::TokenCredential { protected: /** @brief The cache for the access token. */ _detail::TokenCache m_tokenCache; @@ -114,12 +106,13 @@ namespace Azure { namespace Identity { Core::Credentials::TokenRequestContext const& tokenRequestContext, Core::Context const& context) const override; +#if !defined(TESTING_BUILD) private: - _azure_VIRTUAL_FOR_TESTS std::string GetAzCommand( - std::string const& scopes, - std::string const& tenantId) const; - - _azure_VIRTUAL_FOR_TESTS int GetLocalTimeToUtcDiffSeconds() const; +#else + protected: +#endif + virtual std::string GetAzCommand(std::string const& scopes, std::string const& tenantId) const; + virtual int GetLocalTimeToUtcDiffSeconds() const; }; }} // namespace Azure::Identity diff --git a/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp index a555e432a..a209f94b9 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/default_azure_credential.hpp @@ -13,7 +13,7 @@ #include -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) class DefaultAzureCredential_CachingCredential_Test; #endif @@ -42,7 +42,7 @@ namespace Azure { namespace Identity { */ class DefaultAzureCredential final : public Core::Credentials::TokenCredential { -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // make tests classes friends to validate caching friend class ::DefaultAzureCredential_CachingCredential_Test; #endif diff --git a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp index 97994364c..6ca49b82b 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp @@ -10,7 +10,6 @@ #pragma once #include -#include #include #include @@ -20,30 +19,26 @@ #include #include -#if defined(_azure_TESTING_BUILD) -// Define the class used from tests to validate retry policy -namespace Azure { namespace Identity { namespace Test { - class TestableTokenCache; -}}} // namespace Azure::Identity::Test -#endif - namespace Azure { namespace Identity { namespace _detail { /** * @brief Access token cache. * */ - class TokenCache _azure_NON_FINAL_FOR_TESTS { - -#if defined(_azure_TESTING_BUILD) - friend class Azure::Identity::Test::TestableTokenCache; + class TokenCache +#if !defined(TESTING_BUILD) + final #endif - + { +#if !defined(TESTING_BUILD) private: +#else + protected: +#endif // A test hook that gets invoked before cache write lock gets acquired. - _azure_VIRTUAL_FOR_TESTS void OnBeforeCacheWriteLock() const {}; + virtual void OnBeforeCacheWriteLock() const {}; // A test hook that gets invoked before item write lock gets acquired. - _azure_VIRTUAL_FOR_TESTS void OnBeforeItemWriteLock() const {}; + virtual void OnBeforeItemWriteLock() const {}; struct CacheKey { @@ -68,6 +63,7 @@ namespace Azure { namespace Identity { namespace _detail { mutable std::map, CacheKeyComparator> m_cache; mutable std::shared_timed_mutex m_cacheMutex; + private: TokenCache(TokenCache const&) = delete; TokenCache& operator=(TokenCache const&) = delete; diff --git a/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp b/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp index 95415044b..3dcc39d37 100644 --- a/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp +++ b/sdk/identity/azure-identity/src/private/chained_token_credential_impl.hpp @@ -9,7 +9,7 @@ #include #include -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) class DefaultAzureCredential_CachingCredential_Test; #endif @@ -17,7 +17,7 @@ namespace Azure { namespace Identity { namespace _detail { class ChainedTokenCredentialImpl final { -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) // make tests classes friends to validate caching friend class ::DefaultAzureCredential_CachingCredential_Test; #endif diff --git a/sdk/identity/azure-identity/src/token_cache.cpp b/sdk/identity/azure-identity/src/token_cache.cpp index 5c470d658..acdc4be69 100644 --- a/sdk/identity/azure-identity/src/token_cache.cpp +++ b/sdk/identity/azure-identity/src/token_cache.cpp @@ -39,7 +39,7 @@ std::shared_ptr TokenCache::GetOrCreateValue( } } -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) OnBeforeCacheWriteLock(); #endif @@ -101,7 +101,7 @@ AccessToken TokenCache::GetToken( } } -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) OnBeforeItemWriteLock(); #endif diff --git a/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp index 0ff9087b7..cd7994a6d 100644 --- a/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/azure_cli_credential_test.cpp @@ -22,602 +22,617 @@ using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Identity::AzureCliCredentialOptions; -namespace Azure { namespace Identity { namespace Test { - constexpr auto InfiniteCommand = +namespace { +constexpr auto InfiniteCommand = #if defined(AZ_PLATFORM_WINDOWS) - "for /l %q in (0) do timeout 10"; + "for /l %q in (0) do timeout 10"; #else - "while true; do sleep 10; done" + "while true; do sleep 10; done" #endif - ; +; - constexpr auto EmptyOutputCommand = +constexpr auto EmptyOutputCommand = #if defined(AZ_PLATFORM_WINDOWS) - "rem"; + "rem"; #else - "clear" + "clear" #endif - ; +; - std::string EchoCommand(std::string const text) +std::string EchoCommand(std::string const text) +{ +#if defined(AZ_PLATFORM_WINDOWS) + return std::string("echo ") + text; +#else + return std::string("echo \'") + text + "\'"; +#endif +} + +class AzureCliTestCredential : public AzureCliCredential { +private: + std::string m_command; + int m_localTimeToUtcDiffSeconds = 0; + + std::string GetAzCommand(std::string const& resource, std::string const& tenantId) const override { -#if defined(AZ_PLATFORM_WINDOWS) - return std::string("echo ") + text; -#else - return std::string("echo \'") + text + "\'"; -#endif + static_cast(resource); + static_cast(tenantId); + + return m_command; } - class AzureCliTestCredential : public AzureCliCredential { - private: - std::string m_command; - int m_localTimeToUtcDiffSeconds = 0; + int GetLocalTimeToUtcDiffSeconds() const override { return m_localTimeToUtcDiffSeconds; } - std::string GetAzCommand(std::string const& resource, std::string const& tenantId) - const override - { - static_cast(resource); - static_cast(tenantId); +public: + explicit AzureCliTestCredential(std::string command) : m_command(std::move(command)) {} - return m_command; - } - - int GetLocalTimeToUtcDiffSeconds() const override { return m_localTimeToUtcDiffSeconds; } - - public: - explicit AzureCliTestCredential(std::string command) : m_command(std::move(command)) {} - - explicit AzureCliTestCredential(std::string command, AzureCliCredentialOptions const& options) - : AzureCliCredential(options), m_command(std::move(command)) - { - } - - explicit AzureCliTestCredential(std::string command, TokenCredentialOptions const& options) - : AzureCliCredential(options), m_command(std::move(command)) - { - } - - std::string GetOriginalAzCommand(std::string const& resource, std::string const& tenantId) const - { - return AzureCliCredential::GetAzCommand(resource, tenantId); - } - - decltype(m_tenantId) const& GetTenantId() const { return m_tenantId; } - decltype(m_cliProcessTimeout) const& GetCliProcessTimeout() const - { - return m_cliProcessTimeout; - } - - void SetLocalTimeToUtcDiffSeconds(int diff) { m_localTimeToUtcDiffSeconds = diff; } - }; - -#if !defined(AZ_PLATFORM_WINDOWS) \ - || (!defined(WINAPI_PARTITION_DESKTOP) || WINAPI_PARTITION_DESKTOP) // not UWP - TEST(AzureCliCredential, Success) -#else - TEST(AzureCliCredential, NotAvailable) -#endif + explicit AzureCliTestCredential(std::string command, AzureCliCredentialOptions const& options) + : AzureCliCredential(options), m_command(std::move(command)) { - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," - "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," - "\"tenant\":\"72f988bf-86f1-41af-91ab-2d7cd011db47\"," - "\"tokenType\":\"Bearer\"}"; + } - AzureCliTestCredential const azCliCred(EchoCommand(Token)); + explicit AzureCliTestCredential(std::string command, TokenCredentialOptions const& options) + : AzureCliCredential(options), m_command(std::move(command)) + { + } + + std::string GetOriginalAzCommand(std::string const& resource, std::string const& tenantId) const + { + return AzureCliCredential::GetAzCommand(resource, tenantId); + } + + decltype(m_tenantId) const& GetTenantId() const { return m_tenantId; } + decltype(m_cliProcessTimeout) const& GetCliProcessTimeout() const { return m_cliProcessTimeout; } + + void SetLocalTimeToUtcDiffSeconds(int diff) { m_localTimeToUtcDiffSeconds = diff; } +}; +} // namespace - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); #if !defined(AZ_PLATFORM_WINDOWS) \ || (!defined(WINAPI_PARTITION_DESKTOP) || WINAPI_PARTITION_DESKTOP) // not UWP - auto const token = azCliCred.GetToken(trc, {}); +TEST(AzureCliCredential, Success) +#else +TEST(AzureCliCredential, NotAvailable) +#endif +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," + "\"tenant\":\"72f988bf-86f1-41af-91ab-2d7cd011db47\"," + "\"tokenType\":\"Bearer\"}"; - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + AzureCliTestCredential const azCliCred(EchoCommand(Token)); - EXPECT_EQ( - token.ExpiresOn, - DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); +#if !defined(AZ_PLATFORM_WINDOWS) \ + || (!defined(WINAPI_PARTITION_DESKTOP) || WINAPI_PARTITION_DESKTOP) // not UWP + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); #else // UWP - // The credential should throw during GetToken() and not during construction, because it - // allows customers to put it into ChainedTokenCredential and successfully use it there - // without writing ifdefs for UWP. It is not too late to throw - for example, if Azure CLI is - // not installed, then the credential will also find out during GetToken() and not during - // construction (if we had to find out during the construction, we'd have to fire up some 'az' - // command in constructor; again, that would also make it hard to put the credential into - // ChainedTokenCredential). - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + // The credential should throw during GetToken() and not during construction, because it allows + // customers to put it into ChainedTokenCredential and successfully use it there without writing + // ifdefs for UWP. It is not too late to throw - for example, if Azure CLI is not installed, then + // the credential will also find out during GetToken() and not during construction (if we had to + // find out during the construction, we'd have to fire up some 'az' command in constructor; again, + // that would also make it hard to put the credential into ChainedTokenCredential). + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); #endif // UWP - } +} #if !defined(AZ_PLATFORM_WINDOWS) \ || (!defined(WINAPI_PARTITION_DESKTOP) || WINAPI_PARTITION_DESKTOP) // not UWP - TEST(AzureCliCredential, Error) - { - using Azure::Core::Diagnostics::Logger; - using LogMsgVec = std::vector>; - LogMsgVec log; - Logger::SetLevel(Logger::Level::Informational); - Logger::SetListener([&](auto lvl, auto msg) { log.push_back(std::make_pair(lvl, msg)); }); +TEST(AzureCliCredential, Error) +{ + using Azure::Core::Diagnostics::Logger; + using LogMsgVec = std::vector>; + LogMsgVec log; + Logger::SetLevel(Logger::Level::Informational); + Logger::SetListener([&](auto lvl, auto msg) { log.push_back(std::make_pair(lvl, msg)); }); - AzureCliTestCredential const azCliCred( - EchoCommand("ERROR: Please run az login to setup account.")); + AzureCliTestCredential const azCliCred( + EchoCommand("ERROR: Please run az login to setup account.")); - EXPECT_EQ(log.size(), LogMsgVec::size_type(1)); - EXPECT_EQ(log[0].first, Logger::Level::Informational); - EXPECT_EQ( - log[0].second, - "Identity: AzureCliCredential created." - "\nSuccessful creation does not guarantee further successful token retrieval."); + EXPECT_EQ(log.size(), LogMsgVec::size_type(1)); + EXPECT_EQ(log[0].first, Logger::Level::Informational); + EXPECT_EQ( + log[0].second, + "Identity: AzureCliCredential created." + "\nSuccessful creation does not guarantee further successful token retrieval."); - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); - log.clear(); - auto const errorMsg = "Identity: AzureCliCredential didn't get the token:" - " \"ERROR: Please run az login to setup account." + log.clear(); + auto const errorMsg = "Identity: AzureCliCredential didn't get the token:" + " \"ERROR: Please run az login to setup account." #if defined(AZ_PLATFORM_WINDOWS) - "\r" + "\r" #endif - "\n\""; + "\n\""; - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - EXPECT_EQ(log.size(), LogMsgVec::size_type(1)); - EXPECT_EQ(log[0].first, Logger::Level::Warning); - EXPECT_EQ(log[0].second, errorMsg); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + EXPECT_EQ(log.size(), LogMsgVec::size_type(1)); + EXPECT_EQ(log[0].first, Logger::Level::Warning); + EXPECT_EQ(log[0].second, errorMsg); - Logger::SetListener(nullptr); - } + Logger::SetListener(nullptr); +} - TEST(AzureCliCredential, GetCredentialName) +TEST(AzureCliCredential, GetCredentialName) +{ + AzureCliTestCredential const cred(EmptyOutputCommand); + EXPECT_EQ(cred.GetCredentialName(), "AzureCliCredential"); +} + +TEST(AzureCliCredential, EmptyOutput) +{ + AzureCliTestCredential const azCliCred(EmptyOutputCommand); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); +} + +TEST(AzureCliCredential, BigToken) +{ + std::string accessToken; { - AzureCliTestCredential const cred(EmptyOutputCommand); - EXPECT_EQ(cred.GetCredentialName(), "AzureCliCredential"); - } - - TEST(AzureCliCredential, EmptyOutput) - { - AzureCliTestCredential const azCliCred(EmptyOutputCommand); - - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - } - - TEST(AzureCliCredential, BigToken) - { - std::string accessToken; + std::string const tokenPart = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + auto const nIterations = ((4 * 1024) / tokenPart.size()) + 1; + for (auto i = 0; i < static_cast(nIterations); ++i) { - std::string const tokenPart = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - auto const nIterations = ((4 * 1024) / tokenPart.size()) + 1; - for (auto i = 0; i < static_cast(nIterations); ++i) - { - accessToken += tokenPart; - } + accessToken += tokenPart; + } + } + + AzureCliTestCredential const azCliCred(EchoCommand( + std::string("{\"accessToken\":\"") + accessToken + + "\",\"expiresOn\":\"2022-08-24 00:43:08.000000\"}")); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, accessToken); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, ExpiresIn) +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\",\"expiresIn\":30}"; + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const timestampBefore = std::chrono::system_clock::now(); + auto const token = azCliCred.GetToken(trc, {}); + auto const timestampAfter = std::chrono::system_clock::now(); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_GE(token.ExpiresOn, timestampBefore + std::chrono::seconds(30)); + EXPECT_LE(token.ExpiresOn, timestampAfter + std::chrono::seconds(30)); +} + +TEST(AzureCliCredential, ExpiresOnUnixTimestampInt) +{ + // 'expires_on' is 1700692424, which is a Unix timestamp of a date in 2023. + // 'ExpiresOn' is a date in 2022. + // The test checks that when both are present, 'expires_on' value (2023) is taken, + // and not that of 'ExpiresOn'. + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," // <-- 2022 + "\"expires_on\":1700692424}"; // <-- 2023 + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2023-11-22T22:33:44.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, ExpiresOnUnixTimestampString) +{ + // 'expires_on' is 1700692424, which is a Unix timestamp of a date in 2023. + // 'expiresOn' is a date in 2022. + // The test checks that when both are present, 'expires_on' value (2023) is taken, + // and not that of 'expiresOn'. + // The test is similar to the one above, but the Unix timestamp is represented as string + // containing an integer. + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," // <-- 2022 + "\"expires_on\":\"1700692424\"}"; // <-- 2023 + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2023-11-22T22:33:44.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, TimedOut) +{ + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::seconds(2); + AzureCliTestCredential const azCliCred(InfiniteCommand, options); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); +} + +TEST(AzureCliCredential, ContextCancelled) +{ + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliTestCredential const azCliCred(InfiniteCommand, options); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + + auto context = Context::ApplicationContext.WithDeadline( + std::chrono::system_clock::now() + std::chrono::hours(24)); + + std::atomic thread1Started(false); + + std::thread thread1([&]() { + thread1Started = true; + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, context)), AuthenticationException); + }); + + std::thread thread2([&]() { + while (!thread1Started) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - AzureCliTestCredential const azCliCred(EchoCommand( - std::string("{\"accessToken\":\"") + accessToken - + "\",\"expiresOn\":\"2022-08-24 00:43:08.000000\"}")); + std::this_thread::sleep_for(std::chrono::seconds(2)); + context.Cancel(); + }); - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); + thread1.join(); + thread2.join(); +} - auto const token = azCliCred.GetToken(trc, {}); - - EXPECT_EQ(token.Token, accessToken); - - EXPECT_EQ( - token.ExpiresOn, - DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); - } - - TEST(AzureCliCredential, ExpiresIn) +TEST(AzureCliCredential, Defaults) +{ { - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\",\"expiresIn\":30}"; + AzureCliCredentialOptions const DefaultOptions; - AzureCliTestCredential const azCliCred(EchoCommand(Token)); + { + AzureCliTestCredential azCliCred({}); + EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + } - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - - auto const timestampBefore = std::chrono::system_clock::now(); - auto const token = azCliCred.GetToken(trc, {}); - auto const timestampAfter = std::chrono::system_clock::now(); - - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - - EXPECT_GE(token.ExpiresOn, timestampBefore + std::chrono::seconds(30)); - EXPECT_LE(token.ExpiresOn, timestampAfter + std::chrono::seconds(30)); + { + TokenCredentialOptions const options; + AzureCliTestCredential azCliCred({}, options); + EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); + } } - TEST(AzureCliCredential, ExpiresOnUnixTimestampInt) - { - // 'expires_on' is 1700692424, which is a Unix timestamp of a date in 2023. - // 'ExpiresOn' is a date in 2022. - // The test checks that when both are present, 'expires_on' value (2023) is taken, - // and not that of 'ExpiresOn'. - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," - "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," // <-- 2022 - "\"expires_on\":1700692424}"; // <-- 2023 - - AzureCliTestCredential const azCliCred(EchoCommand(Token)); - - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - - auto const token = azCliCred.GetToken(trc, {}); - - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - - EXPECT_EQ( - token.ExpiresOn, - DateTime::Parse("2023-11-22T22:33:44.000000Z", DateTime::DateFormat::Rfc3339)); - } - - TEST(AzureCliCredential, ExpiresOnUnixTimestampString) - { - // 'expires_on' is 1700692424, which is a Unix timestamp of a date in 2023. - // 'expiresOn' is a date in 2022. - // The test checks that when both are present, 'expires_on' value (2023) is taken, - // and not that of 'expiresOn'. - // The test is similar to the one above, but the Unix timestamp is represented as string - // containing an integer. - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," - "\"expiresOn\":\"2022-08-24 00:43:08.000000\"," // <-- 2022 - "\"expires_on\":\"1700692424\"}"; // <-- 2023 - - AzureCliTestCredential const azCliCred(EchoCommand(Token)); - - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - - auto const token = azCliCred.GetToken(trc, {}); - - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - - EXPECT_EQ( - token.ExpiresOn, - DateTime::Parse("2023-11-22T22:33:44.000000Z", DateTime::DateFormat::Rfc3339)); - } - - TEST(AzureCliCredential, TimedOut) { AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::seconds(2); - AzureCliTestCredential const azCliCred(InfiniteCommand, options); + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.CliProcessTimeout = std::chrono::seconds(12345); + + AzureCliTestCredential azCliCred({}, options); + + EXPECT_EQ(azCliCred.GetTenantId(), "01234567-89AB-CDEF-0123-456789ABCDEF"); + EXPECT_EQ(azCliCred.GetCliProcessTimeout(), std::chrono::seconds(12345)); + } +} + +TEST(AzureCliCredential, CmdLine) +{ + AzureCliTestCredential azCliCred({}); + + auto const cmdLineWithoutTenant + = azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {}); + + auto const cmdLineWithTenant = azCliCred.GetOriginalAzCommand( + "https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF"); + + EXPECT_EQ( + cmdLineWithoutTenant, + "az account get-access-token --output json --scope \"https://storage.azure.com/.default\""); + + EXPECT_EQ( + cmdLineWithTenant, + "az account get-access-token --output json --scope \"https://storage.azure.com/.default\"" + " --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\""); +} + +TEST(AzureCliCredential, UnsafeChars) +{ + std::string const Exploit = std::string("\" | echo OWNED | ") + InfiniteCommand + " | echo \""; + + { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.TenantId += Exploit; + AzureCliCredential azCliCred(options); TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); } - TEST(AzureCliCredential, ContextCancelled) { AzureCliCredentialOptions options; options.CliProcessTimeout = std::chrono::hours(24); - AzureCliTestCredential const azCliCred(InfiniteCommand, options); + AzureCliCredential azCliCred(options); TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); + trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Exploit); - auto context = Context::ApplicationContext.WithDeadline( - std::chrono::system_clock::now() + std::chrono::hours(24)); - - std::atomic thread1Started(false); - - std::thread thread1([&]() { - thread1Started = true; - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, context)), AuthenticationException); - }); - - std::thread thread2([&]() { - while (!thread1Started) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - - std::this_thread::sleep_for(std::chrono::seconds(2)); - context.Cancel(); - }); - - thread1.join(); - thread2.join(); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); } +} - TEST(AzureCliCredential, Defaults) +class ParameterizedTestForDisallowedChars : public ::testing::TestWithParam { +protected: + std::string value; +}; + +TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) +{ + std::string const InvalidValue = GetParam(); + + // Tenant ID test via AzureCliCredentialOptions directly. { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.TenantId += InvalidValue; + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + + try { - AzureCliCredentialOptions const DefaultOptions; - - { - AzureCliTestCredential azCliCred({}); - EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); - EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); - } - - { - TokenCredentialOptions const options; - AzureCliTestCredential azCliCred({}, options); - EXPECT_EQ(azCliCred.GetTenantId(), DefaultOptions.TenantId); - EXPECT_EQ(azCliCred.GetCliProcessTimeout(), DefaultOptions.CliProcessTimeout); - } + auto const token = azCliCred.GetToken(trc, {}); } - + catch (AuthenticationException const& e) { - AzureCliCredentialOptions options; - options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.CliProcessTimeout = std::chrono::seconds(12345); - - AzureCliTestCredential azCliCred({}, options); - - EXPECT_EQ(azCliCred.GetTenantId(), "01234567-89AB-CDEF-0123-456789ABCDEF"); - EXPECT_EQ(azCliCred.GetCliProcessTimeout(), std::chrono::seconds(12345)); + EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); } } - TEST(AzureCliCredential, CmdLine) + // Tenant ID test via TokenRequestContext, using a wildcard for AdditionallyAllowedTenants. { - AzureCliTestCredential azCliCred({}); + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + options.AdditionallyAllowedTenants.push_back("*"); + AzureCliCredential azCliCred(options); - auto const cmdLineWithoutTenant - = azCliCred.GetOriginalAzCommand("https://storage.azure.com/.default", {}); - - auto const cmdLineWithTenant = azCliCred.GetOriginalAzCommand( - "https://storage.azure.com/.default", "01234567-89AB-CDEF-0123-456789ABCDEF"); - - EXPECT_EQ( - cmdLineWithoutTenant, - "az account get-access-token --output json --scope " - "\"https://storage.azure.com/.default\""); - - EXPECT_EQ( - cmdLineWithTenant, - "az account get-access-token --output json --scope \"https://storage.azure.com/.default\"" - " --tenant \"01234567-89AB-CDEF-0123-456789ABCDEF\""); - } - - TEST(AzureCliCredential, UnsafeChars) - { - std::string const Exploit = std::string("\" | echo OWNED | ") + InfiniteCommand + " | echo \""; + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + trc.TenantId = InvalidValue; + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + try { - AzureCliCredentialOptions options; - options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.TenantId += Exploit; - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + auto const token = azCliCred.GetToken(trc, {}); } - + catch (AuthenticationException const& e) { - AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::hours(24); - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + Exploit); - - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); } } - class ParameterizedTestForDisallowedChars : public ::testing::TestWithParam { - protected: - std::string value; - }; - - TEST_P(ParameterizedTestForDisallowedChars, DisallowedCharsForScopeAndTenantId) + // Tenant ID test via TokenRequestContext, using a specific AdditionallyAllowedTenants value. { - std::string const InvalidValue = GetParam(); + AzureCliCredentialOptions options; + options.AdditionallyAllowedTenants.push_back(InvalidValue); + AzureCliCredential azCliCred(options); - // Tenant ID test via AzureCliCredentialOptions directly. + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + trc.TenantId = InvalidValue; + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + + try { - AzureCliCredentialOptions options; - options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.TenantId += InvalidValue; - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); - } + auto const token = azCliCred.GetToken(trc, {}); } - - // Tenant ID test via TokenRequestContext, using a wildcard for AdditionallyAllowedTenants. + catch (AuthenticationException const& e) { - AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::hours(24); - options.AdditionallyAllowedTenants.push_back("*"); - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - trc.TenantId = InvalidValue; - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); - } - } - - // Tenant ID test via TokenRequestContext, using a specific AdditionallyAllowedTenants value. - { - AzureCliCredentialOptions options; - options.AdditionallyAllowedTenants.push_back(InvalidValue); - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - trc.TenantId = InvalidValue; - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); - } - } - - // Scopes test via TokenRequestContext. - { - AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::hours(24); - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + InvalidValue); - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); - } + EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); } } - INSTANTIATE_TEST_SUITE_P( - AzureCliCredential, - ParameterizedTestForDisallowedChars, - ::testing::Values(" ", "|", "`", "\"", "'", ";", "&")); - - class ParameterizedTestForCharDifferences : public ::testing::TestWithParam { - protected: - std::string value; - }; - - TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId) + // Scopes test via TokenRequestContext. { - std::string const ValidScopeButNotTenantId = GetParam(); + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliCredential azCliCred(options); + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + InvalidValue); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + + try { - AzureCliCredentialOptions options; - options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.TenantId += ValidScopeButNotTenantId; - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); - - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); - } + auto const token = azCliCred.GetToken(trc, {}); } - + catch (AuthenticationException const& e) { - AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::hours(24); - AzureCliCredential azCliCred(options); + EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); + } + } +} - TokenRequestContext trc; - trc.Scopes.push_back( - std::string("https://storage.azure.com/.default") + ValidScopeButNotTenantId); +INSTANTIATE_TEST_SUITE_P( + AzureCliCredential, + ParameterizedTestForDisallowedChars, + ::testing::Values(" ", "|", "`", "\"", "'", ";", "&")); - // We expect the GetToken to fail, but not because of the unsafe chars. - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); - } +class ParameterizedTestForCharDifferences : public ::testing::TestWithParam { +protected: + std::string value; +}; + +TEST_P(ParameterizedTestForCharDifferences, ValidCharsForScopeButNotTenantId) +{ + std::string const ValidScopeButNotTenantId = GetParam(); + + { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.TenantId += ValidScopeButNotTenantId; + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + EXPECT_THROW(static_cast(azCliCred.GetToken(trc, {})), AuthenticationException); + + try + { + auto const token = azCliCred.GetToken(trc, {}); + } + catch (AuthenticationException const& e) + { + EXPECT_TRUE(std::string(e.what()).find("Unsafe") != std::string::npos) << e.what(); } } - INSTANTIATE_TEST_SUITE_P( - AzureCliCredential, - ParameterizedTestForCharDifferences, - ::testing::Values(":", "/", "_")); - - class ParameterizedTestForAllowedChars : public ::testing::TestWithParam { - protected: - std::string value; - }; - - TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId) { - std::string const ValidChars = GetParam(); + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliCredential azCliCred(options); + TokenRequestContext trc; + trc.Scopes.push_back( + std::string("https://storage.azure.com/.default") + ValidScopeButNotTenantId); + + // We expect the GetToken to fail, but not because of the unsafe chars. + try { - AzureCliCredentialOptions options; - options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; - options.TenantId += ValidChars; - AzureCliCredential azCliCred(options); - - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); - - // We expect the GetToken to fail, but not because of the unsafe chars. - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); - } + auto const token = azCliCred.GetToken(trc, {}); } - + catch (AuthenticationException const& e) { - AzureCliCredentialOptions options; - options.CliProcessTimeout = std::chrono::hours(24); - AzureCliCredential azCliCred(options); + EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); + } + } +} - TokenRequestContext trc; - trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + ValidChars); +INSTANTIATE_TEST_SUITE_P( + AzureCliCredential, + ParameterizedTestForCharDifferences, + ::testing::Values(":", "/", "_")); - // We expect the GetToken to fail, but not because of the unsafe chars. - try - { - auto const token = azCliCred.GetToken(trc, {}); - } - catch (AuthenticationException const& e) - { - EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); - } +class ParameterizedTestForAllowedChars : public ::testing::TestWithParam { +protected: + std::string value; +}; + +TEST_P(ParameterizedTestForAllowedChars, ValidCharsForScopeAndTenantId) +{ + std::string const ValidChars = GetParam(); + + { + AzureCliCredentialOptions options; + options.TenantId = "01234567-89AB-CDEF-0123-456789ABCDEF"; + options.TenantId += ValidChars; + AzureCliCredential azCliCred(options); + + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default")); + + // We expect the GetToken to fail, but not because of the unsafe chars. + try + { + auto const token = azCliCred.GetToken(trc, {}); + } + catch (AuthenticationException const& e) + { + EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); } } - INSTANTIATE_TEST_SUITE_P( - AzureCliCredential, - ParameterizedTestForAllowedChars, - ::testing::Values(".", "-", "A", "9")); - - TEST(AzureCliCredential, StrictIso8601TimeFormat) { - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," - "\"expiresOn\":\"2022-08-24T00:43:08\"}"; // With the "T" + AzureCliCredentialOptions options; + options.CliProcessTimeout = std::chrono::hours(24); + AzureCliCredential azCliCred(options); - AzureCliTestCredential const azCliCred(EchoCommand(Token)); + TokenRequestContext trc; + trc.Scopes.push_back(std::string("https://storage.azure.com/.default") + ValidChars); + + // We expect the GetToken to fail, but not because of the unsafe chars. + try + { + auto const token = azCliCred.GetToken(trc, {}); + } + catch (AuthenticationException const& e) + { + EXPECT_TRUE(std::string(e.what()).find("Unsafe") == std::string::npos) << e.what(); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + AzureCliCredential, + ParameterizedTestForAllowedChars, + ::testing::Values(".", "-", "A", "9")); + +TEST(AzureCliCredential, StrictIso8601TimeFormat) +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2022-08-24T00:43:08\"}"; // With the "T" + + AzureCliTestCredential const azCliCred(EchoCommand(Token)); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + auto const token = azCliCred.GetToken(trc, {}); + + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + EXPECT_EQ( + token.ExpiresOn, + DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); +} + +TEST(AzureCliCredential, LocalTime) +{ + constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," + "\"expiresOn\":\"2023-12-07 00:43:08\"}"; + + { + AzureCliTestCredential azCliCred(EchoCommand(Token)); + azCliCred.SetLocalTimeToUtcDiffSeconds(-28800); // Redmond (no DST) TokenRequestContext trc; trc.Scopes.push_back("https://storage.azure.com/.default"); @@ -626,92 +641,71 @@ namespace Azure { namespace Identity { namespace Test { EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); EXPECT_EQ( - token.ExpiresOn, - DateTime::Parse("2022-08-24T00:43:08.000000Z", DateTime::DateFormat::Rfc3339)); + token.ExpiresOn, DateTime::Parse("2023-12-07T08:43:08Z", DateTime::DateFormat::Rfc3339)); } - TEST(AzureCliCredential, LocalTime) { - constexpr auto Token = "{\"accessToken\":\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"," - "\"expiresOn\":\"2023-12-07 00:43:08\"}"; + AzureCliTestCredential azCliCred(EchoCommand(Token)); + azCliCred.SetLocalTimeToUtcDiffSeconds(7200); // Kyiv (no DST) - { - AzureCliTestCredential azCliCred(EchoCommand(Token)); - azCliCred.SetLocalTimeToUtcDiffSeconds(-28800); // Redmond (no DST) + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + auto const token = azCliCred.GetToken(trc, {}); - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - auto const token = azCliCred.GetToken(trc, {}); + EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - - EXPECT_EQ( - token.ExpiresOn, DateTime::Parse("2023-12-07T08:43:08Z", DateTime::DateFormat::Rfc3339)); - } - - { - AzureCliTestCredential azCliCred(EchoCommand(Token)); - azCliCred.SetLocalTimeToUtcDiffSeconds(7200); // Kyiv (no DST) - - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - auto const token = azCliCred.GetToken(trc, {}); - - EXPECT_EQ(token.Token, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); - - EXPECT_EQ( - token.ExpiresOn, DateTime::Parse("2023-12-06T22:43:08Z", DateTime::DateFormat::Rfc3339)); - } + EXPECT_EQ( + token.ExpiresOn, DateTime::Parse("2023-12-06T22:43:08Z", DateTime::DateFormat::Rfc3339)); } +} - TEST(AzureCliCredential, Diagnosability) +TEST(AzureCliCredential, Diagnosability) +{ { + AzureCliTestCredential const azCliCred( + EchoCommand("az is not recognized as an internal or external command, " + "operable program or batch file.")); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + try { - AzureCliTestCredential const azCliCred( - EchoCommand("az is not recognized as an internal or external command, " - "operable program or batch file.")); - - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - try - { - static_cast(azCliCred.GetToken(trc, {})); - } - catch (AuthenticationException const& e) - { - std::string const expectedMsgStart - = "AzureCliCredential didn't get the token: " - "\"az is not recognized as an internal or external command, " - "operable program or batch file."; - - std::string actualMsgStart = e.what(); - actualMsgStart.resize(expectedMsgStart.length()); - - // It is enough to compare StartsWith() and not deal with - // the entire string due to '/n' and '/r/n' differences. - EXPECT_EQ(actualMsgStart, expectedMsgStart); - } + static_cast(azCliCred.GetToken(trc, {})); } - + catch (AuthenticationException const& e) { - AzureCliTestCredential const azCliCred(EchoCommand("{\"property\":\"value\"}")); + std::string const expectedMsgStart + = "AzureCliCredential didn't get the token: " + "\"az is not recognized as an internal or external command, " + "operable program or batch file."; - TokenRequestContext trc; - trc.Scopes.push_back("https://storage.azure.com/.default"); - try - { - static_cast(azCliCred.GetToken(trc, {})); - } - catch (AuthenticationException const& e) - { - EXPECT_EQ( - e.what(), - std::string("AzureCliCredential didn't get the token: " - "\"Token JSON object: can't find or parse 'accessToken' property.\n" - "See Azure::Core::Diagnostics::Logger for details " - "(https://aka.ms/azsdk/cpp/identity/troubleshooting).\"")); - } + std::string actualMsgStart = e.what(); + actualMsgStart.resize(expectedMsgStart.length()); + + // It is enough to compare StartsWith() and not deal with + // the entire string due to '/n' and '/r/n' differences. + EXPECT_EQ(actualMsgStart, expectedMsgStart); } } + + { + AzureCliTestCredential const azCliCred(EchoCommand("{\"property\":\"value\"}")); + + TokenRequestContext trc; + trc.Scopes.push_back("https://storage.azure.com/.default"); + try + { + static_cast(azCliCred.GetToken(trc, {})); + } + catch (AuthenticationException const& e) + { + EXPECT_EQ( + e.what(), + std::string("AzureCliCredential didn't get the token: " + "\"Token JSON object: can't find or parse 'accessToken' property.\n" + "See Azure::Core::Diagnostics::Logger for details " + "(https://aka.ms/azsdk/cpp/identity/troubleshooting).\"")); + } + } +} #endif // not UWP -}}} // namespace Azure::Identity::Test diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp index 409422081..550dd7150 100644 --- a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -12,109 +12,142 @@ using Azure::DateTime; using Azure::Core::Credentials::AccessToken; using Azure::Identity::_detail::TokenCache; -namespace Azure { namespace Identity { namespace Test { - class TestableTokenCache final : public TokenCache { - public: - using TokenCache::CacheValue; - using TokenCache::m_cache; - using TokenCache::m_cacheMutex; +namespace { +class TestableTokenCache final : public TokenCache { +public: + using TokenCache::CacheValue; + using TokenCache::m_cache; + using TokenCache::m_cacheMutex; - mutable std::function m_onBeforeCacheWriteLock; - mutable std::function m_onBeforeItemWriteLock; + mutable std::function m_onBeforeCacheWriteLock; + mutable std::function m_onBeforeItemWriteLock; - void OnBeforeCacheWriteLock() const override - { - if (m_onBeforeCacheWriteLock != nullptr) - { - m_onBeforeCacheWriteLock(); - } - } - - void OnBeforeItemWriteLock() const override - { - if (m_onBeforeItemWriteLock != nullptr) - { - m_onBeforeItemWriteLock(); - } - } - }; - - using namespace std::chrono_literals; - - TEST(TokenCache, GetReuseRefresh) + void OnBeforeCacheWriteLock() const override { - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; - + if (m_onBeforeCacheWriteLock != nullptr) { - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); - EXPECT_EQ(token1.Token, token2.Token); - } - - { - tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday; - - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T3"; - result.ExpiresOn = Tomorrow + 1min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); - EXPECT_EQ(token.Token, "T3"); + m_onBeforeCacheWriteLock(); } } - TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) + void OnBeforeItemWriteLock() const override + { + if (m_onBeforeItemWriteLock != nullptr) + { + m_onBeforeItemWriteLock(); + } + } +}; +} // namespace + +using namespace std::chrono_literals; + +TEST(TokenCache, GetReuseRefresh) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + { + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); + EXPECT_EQ(token1.Token, token2.Token); + } + + { + tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday; + + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); + EXPECT_EQ(token.Token, "T3"); + } +} + +TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + tokenCache.m_onBeforeCacheWriteLock = [&]() { + tokenCache.m_onBeforeCacheWriteLock = nullptr; + static_cast(tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + }; + + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " + "acquiring cache write lock"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "T1"); +} + +TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) +{ + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + { TestableTokenCache tokenCache; EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - tokenCache.m_onBeforeCacheWriteLock = [&]() { - tokenCache.m_onBeforeCacheWriteLock = nullptr; - static_cast(tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache[{"A", {}}]; + item->AccessToken.Token = "T1"; + item->AccessToken.ExpiresOn = Tomorrow; }; auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " - "acquiring cache write lock"); + "acquiring item write lock"); AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow + 1min; @@ -127,547 +160,511 @@ namespace Azure { namespace Identity { namespace Test { EXPECT_EQ(token.Token, "T1"); } - TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) - { - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; - - { - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - tokenCache.m_onBeforeItemWriteLock = [&]() { - tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache[{"A", {}}]; - item->AccessToken.Token = "T1"; - item->AccessToken.ExpiresOn = Tomorrow; - }; - - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE( - "getNewToken does not get invoked when the fresh value was inserted just before " - "acquiring item write lock"); - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "T1"); - } - - // Same as above, but the token that was inserted is already expired. - { - TestableTokenCache tokenCache; - - tokenCache.m_onBeforeItemWriteLock = [&]() { - tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache[{"A", {}}]; - item->AccessToken.Token = "T3"; - item->AccessToken.ExpiresOn = Yesterday; - }; - - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T4"; - result.ExpiresOn = Tomorrow + 3min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); - EXPECT_EQ(token.Token, "T4"); - } - } - - TEST(TokenCache, ExpiredCleanup) - { - // Expected cleanup points are when cache size is in the Fibonacci sequence: - // 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ... - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; - - TestableTokenCache tokenCache; - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - for (auto i = 1; i <= 35; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - // Simply: we added 34+1 token, none of them has expired. None are expected to be cleaned up. - EXPECT_EQ(tokenCache.m_cache.size(), 35UL); - - // Let's expire 3 of them, with numbers from 1 to 3. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; - } - - // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. - for (auto i = 36; i <= 55; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - - // Count is at 55. Tokens from 1 to 3 are still in cache even though they are expired. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - // One more addition to the cache and cleanup for the expired ones will get triggered. - static_cast(tokenCache.GetToken("56", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - - // We were at 55 before we added 1 more, and now we're at 53. 3 were deleted, 1 was added. - EXPECT_EQ(tokenCache.m_cache.size(), 53UL); - - // Items from 1 to 3 should no longer be in the cache. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - // Let's expire items from 21 all the way up to 56. - for (auto i = 21; i <= 56; ++i) - { - auto const n = std::to_string(i); - tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; - } - - // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get - // to 55 items (with numbers from 2 to 56, and number 1 missing). - for (auto i = 2; i <= 3; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - // Cache is now at 55 again (items from 2 to 56). Adding 1 more will trigger cleanup. - EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - - // Now let's lock some of the items for reading, and some for writing. Cleanup should not block - // on token release, but will simply move on, without doing anything to the ones that were - // locked. Out of 4 locked, two are expired, so they should get cleared under normal - // circumstances, but this time they will remain in the cache. - std::shared_lock readLockForUnexpired( - tokenCache.m_cache[{"2", {}}]->ElementMutex); - - std::shared_lock readLockForExpired( - tokenCache.m_cache[{"54", {}}]->ElementMutex); - - std::unique_lock writeLockForUnexpired( - tokenCache.m_cache[{"3", {}}]->ElementMutex); - - std::unique_lock writeLockForExpired( - tokenCache.m_cache[{"55", {}}]->ElementMutex); - - // Count is at 55. Inserting the 56th element, and it will trigger cleanup. - static_cast(tokenCache.GetToken("1", {}, 2min, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow; - return result; - })); - - // These should be 20 unexpired items + two that are expired but were locked, so 22 total. - EXPECT_EQ(tokenCache.m_cache.size(), 22UL); - - for (auto i = 1; i <= 20; ++i) - { - auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end()); - - EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end()); - - for (auto i = 21; i <= 53; ++i) - { - auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - } - - TEST(TokenCache, MinimumExpiration) + // Same as above, but the token that was inserted is already expired. { TestableTokenCache tokenCache; - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); - EXPECT_EQ(token2.Token, "T2"); - } - - TEST(TokenCache, MultithreadedAccess) - { - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - { - std::shared_lock itemReadLock( - tokenCache.m_cache[{"A", {}}]->ElementMutex); - - { - std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); - - // Parallel threads read both the container and the item we're accessing, and we can - // access it in parallel as well. - auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); - EXPECT_EQ(token2.Token, token1.Token); - } - - // The cache is unlocked, but one item is being read in a parallel thread, which does not - // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() { - AccessToken result; - result.Token = "T3"; - result.ExpiresOn = Tomorrow + 2h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 2UL); - - EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); - EXPECT_EQ(token3.Token, "T3"); - } - - { - std::unique_lock itemWriteLock( - tokenCache.m_cache[{"A", {}}]->ElementMutex); - - // The cache is unlocked, but one item is being written in a parallel thread, which does not - // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() { - AccessToken result; - result.Token = "T4"; - result.ExpiresOn = Tomorrow + 3h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 3UL); - - EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); - EXPECT_EQ(token3.Token, "T4"); - } - } - - using Azure::Core::Context; - using Azure::Core::Http::HttpTransport; - using Azure::Core::Http::RawResponse; - using Azure::Core::Http::Request; - - namespace { - class TestTransport final : public HttpTransport { - int m_attemptNumber = 0; - std::vector m_responseBuf; - - public: - // Returns token response with 3600 seconds expiration (1 hour), and the value of the - // client_secret parameter from the body + attempt number as token value. - std::unique_ptr Send(Request& request, Context const&) override - { - using Azure::Core::Http::HttpStatusCode; - using Azure::Core::IO::BodyStream; - using Azure::Core::IO::MemoryBodyStream; - - ++m_attemptNumber; - - std::string clientSecret; - { - std::string const ClientSecretStart = "client_secret="; - - auto const reqBodyVec = request.GetBodyStream()->ReadToEnd(); - auto const reqBodyStr = std::string(reqBodyVec.cbegin(), reqBodyVec.cend()); - - auto clientSecretStartPos = reqBodyStr.find(ClientSecretStart); - if (clientSecretStartPos != std::string::npos) - { - clientSecretStartPos += ClientSecretStart.size(); - auto const clientSecretEndPos = reqBodyStr.find('&', clientSecretStartPos); - - clientSecret = (clientSecretEndPos == std::string::npos) - ? reqBodyStr.substr(clientSecretStartPos) - : reqBodyStr.substr( - clientSecretStartPos, clientSecretEndPos - clientSecretStartPos); - } - } - - auto const respBodyStr = std::string("{ \"access_token\" : \"") + clientSecret - + std::to_string(m_attemptNumber) + "\", \"expires_in\" : 3600 }"; - - m_responseBuf.assign(respBodyStr.cbegin(), respBodyStr.cend()); - - auto resp = std::make_unique(1, 1, HttpStatusCode::Ok, "OK"); - resp->SetBodyStream(std::make_unique(m_responseBuf)); - return resp; - } + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache[{"A", {}}]; + item->AccessToken.Token = "T3"; + item->AccessToken.ExpiresOn = Yesterday; }; - } // namespace - TEST(TokenCache, PerCredInstance) + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3min; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); + EXPECT_EQ(token.Token, "T4"); + } +} + +TEST(TokenCache, ExpiredCleanup) +{ + // Expected cleanup points are when cache size is in the Fibonacci sequence: + // 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ... + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + TestableTokenCache tokenCache; + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + for (auto i = 1; i <= 35; ++i) { - using Azure::Core::Credentials::TokenCredentialOptions; - using Azure::Core::Credentials::TokenRequestContext; - using Azure::Identity::ClientSecretCredential; - - TokenRequestContext getCached; - getCached.Scopes = {"https://vault.azure.net/.default"}; - getCached.MinimumExpiration = 1s; - - TokenCredentialOptions credOptions; - credOptions.Transport.Transport = std::make_shared(); - - ClientSecretCredential credA("TenantId", "ClientId", "SecretA", credOptions); - ClientSecretCredential credB("TenantId", "ClientId", "SecretB", credOptions); - - { - auto const tokenA1 = credA.GetToken(getCached, {}); // Should populate - EXPECT_EQ(tokenA1.Token, "SecretA1"); - } - - { - auto const tokenA2 = credA.GetToken(getCached, {}); // Should get previously populated value - EXPECT_EQ(tokenA2.Token, "SecretA1"); - } - - { - auto const tokenB = credB.GetToken(getCached, {}); - EXPECT_EQ( - tokenB.Token, - "SecretB2"); // if token cache was shared between instances, the value would be - // "SecretA1" - } - - { - auto const tokenA3 = credA.GetToken(getCached, {}); // Should still get the cached value - EXPECT_EQ(tokenA3.Token, "SecretA1"); - } - - auto getNew = getCached; - getNew.MinimumExpiration += 3600s; - - { - auto const tokenA4 = credA.GetToken(getNew, {}); // Should get the new value - EXPECT_EQ(tokenA4.Token, "SecretA3"); - } - - { - auto const tokenA5 = credA.GetToken(getNew, {}); // Should get the new value - EXPECT_EQ(tokenA5.Token, "SecretA4"); - } - - { - auto const tokenA6 - = credA.GetToken(getCached, {}); // Should get the cached, recently refreshed value - EXPECT_EQ(tokenA6.Token, "SecretA4"); - } + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); } - TEST(TokenCache, TenantId) + // Simply: we added 34+1 token, none of them has expired. None are expected to be cleaned up. + EXPECT_EQ(tokenCache.m_cache.size(), 35UL); + + // Let's expire 3 of them, with numbers from 1 to 3. + for (auto i = 1; i <= 3; ++i) { - TestableTokenCache tokenCache; + auto const n = std::to_string(i); + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; + } - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. + for (auto i = 36; i <= 55; ++i) + { + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + } - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); + + // Count is at 55. Tokens from 1 to 3 are still in cache even though they are expired. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + + // One more addition to the cache and cleanup for the expired ones will get triggered. + static_cast(tokenCache.GetToken("56", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // We were at 55 before we added 1 more, and now we're at 53. 3 were deleted, 1 was added. + EXPECT_EQ(tokenCache.m_cache.size(), 53UL); + + // Items from 1 to 3 should no longer be in the cache. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + + // Let's expire items from 21 all the way up to 56. + for (auto i = 21; i <= 56; ++i) + { + auto const n = std::to_string(i); + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; + } + + // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to + // 55 items (with numbers from 2 to 56, and number 1 missing). + for (auto i = 2; i <= 3; ++i) + { + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + } + + // Cache is now at 55 again (items from 2 to 56). Adding 1 more will trigger cleanup. + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); + + // Now let's lock some of the items for reading, and some for writing. Cleanup should not block on + // token release, but will simply move on, without doing anything to the ones that were locked. + // Out of 4 locked, two are expired, so they should get cleared under normal circumstances, but + // this time they will remain in the cache. + std::shared_lock readLockForUnexpired( + tokenCache.m_cache[{"2", {}}]->ElementMutex); + + std::shared_lock readLockForExpired( + tokenCache.m_cache[{"54", {}}]->ElementMutex); + + std::unique_lock writeLockForUnexpired( + tokenCache.m_cache[{"3", {}}]->ElementMutex); + + std::unique_lock writeLockForExpired( + tokenCache.m_cache[{"55", {}}]->ElementMutex); + + // Count is at 55. Inserting the 56th element, and it will trigger cleanup. + static_cast(tokenCache.GetToken("1", {}, 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // These should be 20 unexpired items + two that are expired but were locked, so 22 total. + EXPECT_EQ(tokenCache.m_cache.size(), 22UL); + + for (auto i = 1; i <= 20; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + + EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end()); + + EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end()); + + for (auto i = 21; i <= 53; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } +} + +TEST(TokenCache, MinimumExpiration) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); + EXPECT_EQ(token2.Token, "T2"); +} + +TEST(TokenCache, MultithreadedAccess) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + { + std::shared_lock itemReadLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); { - auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); + + // Parallel threads read both the container and the item we're accessing, and we can + // access it in parallel as well. + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); AccessToken result; - result.Token = "AX"; - result.ExpiresOn = Tomorrow; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; return result; }); EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AX"); + EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); + EXPECT_EQ(token2.Token, token1.Token); } - { - auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { - AccessToken result; - result.Token = "BX"; - result.ExpiresOn = Tomorrow; - return result; - }); + // The cache is unlocked, but one item is being read in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 2h; + return result; + }); - EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BX"); - } - - { - auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { - AccessToken result; - result.Token = "AY"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 3UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AY"); - } - - { - auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { - AccessToken result; - result.Token = "BY"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BY"); - } - - { - auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "XA"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AX"); - } - - { - auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "XB"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BX"); - } - - { - auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "YA"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AY"); - } - - { - auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "YB"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BY"); - } + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); + EXPECT_EQ(token3.Token, "T3"); } -}}} // namespace Azure::Identity::Test + { + std::unique_lock itemWriteLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); + + // The cache is unlocked, but one item is being written in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); + EXPECT_EQ(token3.Token, "T4"); + } +} + +using Azure::Core::Context; +using Azure::Core::Http::HttpTransport; +using Azure::Core::Http::RawResponse; +using Azure::Core::Http::Request; + +namespace { +class TestTransport final : public HttpTransport { + int m_attemptNumber = 0; + std::vector m_responseBuf; + +public: + // Returns token response with 3600 seconds expiration (1 hour), and the value of the + // client_secret parameter from the body + attempt number as token value. + std::unique_ptr Send(Request& request, Context const&) override + { + using Azure::Core::Http::HttpStatusCode; + using Azure::Core::IO::BodyStream; + using Azure::Core::IO::MemoryBodyStream; + + ++m_attemptNumber; + + std::string clientSecret; + { + std::string const ClientSecretStart = "client_secret="; + + auto const reqBodyVec = request.GetBodyStream()->ReadToEnd(); + auto const reqBodyStr = std::string(reqBodyVec.cbegin(), reqBodyVec.cend()); + + auto clientSecretStartPos = reqBodyStr.find(ClientSecretStart); + if (clientSecretStartPos != std::string::npos) + { + clientSecretStartPos += ClientSecretStart.size(); + auto const clientSecretEndPos = reqBodyStr.find('&', clientSecretStartPos); + + clientSecret = (clientSecretEndPos == std::string::npos) + ? reqBodyStr.substr(clientSecretStartPos) + : reqBodyStr.substr(clientSecretStartPos, clientSecretEndPos - clientSecretStartPos); + } + } + + auto const respBodyStr = std::string("{ \"access_token\" : \"") + clientSecret + + std::to_string(m_attemptNumber) + "\", \"expires_in\" : 3600 }"; + + m_responseBuf.assign(respBodyStr.cbegin(), respBodyStr.cend()); + + auto resp = std::make_unique(1, 1, HttpStatusCode::Ok, "OK"); + resp->SetBodyStream(std::make_unique(m_responseBuf)); + return resp; + } +}; +} // namespace + +TEST(TokenCache, PerCredInstance) +{ + using Azure::Core::Credentials::TokenCredentialOptions; + using Azure::Core::Credentials::TokenRequestContext; + using Azure::Identity::ClientSecretCredential; + + TokenRequestContext getCached; + getCached.Scopes = {"https://vault.azure.net/.default"}; + getCached.MinimumExpiration = 1s; + + TokenCredentialOptions credOptions; + credOptions.Transport.Transport = std::make_shared(); + + ClientSecretCredential credA("TenantId", "ClientId", "SecretA", credOptions); + ClientSecretCredential credB("TenantId", "ClientId", "SecretB", credOptions); + + { + auto const tokenA1 = credA.GetToken(getCached, {}); // Should populate + EXPECT_EQ(tokenA1.Token, "SecretA1"); + } + + { + auto const tokenA2 = credA.GetToken(getCached, {}); // Should get previously populated value + EXPECT_EQ(tokenA2.Token, "SecretA1"); + } + + { + auto const tokenB = credB.GetToken(getCached, {}); + EXPECT_EQ( + tokenB.Token, + "SecretB2"); // if token cache was shared between instances, the value would be + // "SecretA1" + } + + { + auto const tokenA3 = credA.GetToken(getCached, {}); // Should still get the cached value + EXPECT_EQ(tokenA3.Token, "SecretA1"); + } + + auto getNew = getCached; + getNew.MinimumExpiration += 3600s; + + { + auto const tokenA4 = credA.GetToken(getNew, {}); // Should get the new value + EXPECT_EQ(tokenA4.Token, "SecretA3"); + } + + { + auto const tokenA5 = credA.GetToken(getNew, {}); // Should get the new value + EXPECT_EQ(tokenA5.Token, "SecretA4"); + } + + { + auto const tokenA6 + = credA.GetToken(getCached, {}); // Should get the cached, recently refreshed value + EXPECT_EQ(tokenA6.Token, "SecretA4"); + } +} + +TEST(TokenCache, TenantId) +{ + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + AccessToken result; + result.Token = "AX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + AccessToken result; + result.Token = "BX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + AccessToken result; + result.Token = "AY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + AccessToken result; + result.Token = "BY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } +} diff --git a/sdk/keyvault/CMakeLists.txt b/sdk/keyvault/CMakeLists.txt index a0596e20a..eec39c807 100644 --- a/sdk/keyvault/CMakeLists.txt +++ b/sdk/keyvault/CMakeLists.txt @@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) if(BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) endif() add_subdirectory(azure-security-keyvault-keys) diff --git a/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt b/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt index 8f35a7ed9..084e41fb6 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt +++ b/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt @@ -102,7 +102,7 @@ generate_documentation(azure-security-keyvault-certificates ${AZ_LIBRARY_VERSION if(BUILD_TESTING) # define a symbol that enables some test hooks in code - add_compile_definitions(_azure_TESTING_BUILD) + add_compile_definitions(TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES OR FETCH_SOURCE_DEPS) include(AddGoogleTest) diff --git a/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp b/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp index 4b3938b78..d43fc8969 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp @@ -16,14 +16,13 @@ #include #include #include -#include #include #include #include namespace Azure { namespace Security { namespace KeyVault { namespace Certificates { -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) namespace Test { class KeyVaultCertificateClientTest; } @@ -34,10 +33,14 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Certificat * * @details The client supports retrieving KeyVaultCertificate. */ - class CertificateClient final { + class CertificateClient +#if !defined(TESTING_BUILD) + final +#endif + { friend class CreateCertificateOperation; -#if defined(_azure_TESTING_BUILD) +#if defined(TESTING_BUILD) friend class Test::KeyVaultCertificateClientTest; #endif diff --git a/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp b/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp index 633348325..79c5f13fa 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include @@ -35,7 +34,11 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Keys { * Vault. The client supports creating, retrieving, updating, deleting, purging, backing up, * restoring, and listing the KeyVaultKey. */ - class KeyClient _azure_NON_FINAL_FOR_TESTS { + class KeyClient +#if !defined(TESTING_BUILD) + final +#endif + { protected: // Using a shared pipeline for a client to share it with LRO (like delete key) /** @brief the base URL for this keyvault instance. */ diff --git a/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp b/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp index 672feca7c..6cb87e5da 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp @@ -18,7 +18,6 @@ #include #include -#include #include #include @@ -43,7 +42,11 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Secrets { * Vault. The client supports creating, retrieving, updating, deleting, purging, backing up, * restoring, and listing the secret. */ - class SecretClient final { + class SecretClient +#if !defined(TESTING_BUILD) + final +#endif + { private: // Using a shared pipeline for a client to share it with LRO (like delete key)