From 3d7eaddb9ddd2aed059ade7d7bee1bd6af4b31cc Mon Sep 17 00:00:00 2001 From: Ahson Khan Date: Thu, 29 Feb 2024 20:48:12 -0800 Subject: [PATCH] Use new macros in existing surface area, so that classes marked as final don't have virtual methods. (#5389) * Use new macros in existing surface area, so that classes marked as final don't have virtual methods. * Update doc comments. * Use DOXYGEN_PREDEFINED to expand only the macros we want expanded. * Add the compile definition for more projects. * Address PR feedback. * Make TestableTokenCache a friend class of TokenCache. --- cmake-modules/AzureDoxygen.cmake | 8 + sdk/core/azure-core-amqp/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + sdk/core/azure-core/CMakeLists.txt | 2 + .../inc/azure/core/http/policies/policy.hpp | 25 +- .../inc/azure/core/internal/test_hooks.hpp | 47 + .../azure-core/test/ut/retry_policy_test.cpp | 1598 +++++++++-------- .../azure/identity/azure_cli_credential.hpp | 8 +- .../inc/azure/identity/detail/token_cache.hpp | 26 +- .../inc/azure/identity/dll_import_export.hpp | 10 - .../test/ut/token_cache_test.cpp | 1149 ++++++------ sdk/keyvault/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../certificates/certificate_client.hpp | 7 +- .../inc/azure/keyvault/keys/key_client.hpp | 7 +- .../azure/keyvault/secrets/secret_client.hpp | 7 +- 16 files changed, 1478 insertions(+), 1420 deletions(-) create mode 100644 sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp diff --git a/cmake-modules/AzureDoxygen.cmake b/cmake-modules/AzureDoxygen.cmake index 31b2e2ffd..2a410e038 100644 --- a/cmake-modules/AzureDoxygen.cmake +++ b/cmake-modules/AzureDoxygen.cmake @@ -26,6 +26,14 @@ function(generate_documentation PROJECT_NAME PROJECT_VERSION) # classes and enums directly into the documentation. set(DOXYGEN_INLINE_SOURCES NO) set(DOXYGEN_MARKDOWN_ID_STYLE GITHUB) + # Used to correctly expand macros like _azure_NON_FINAL_FOR_TESTS when generating docs. + # Using EXPAND_ONLY_PREDEF to limit macro expansion to the macros specified with the PREDEFINED tags. + set(DOXYGEN_MACRO_EXPANSION YES) + set(EXPAND_ONLY_PREDEF YES) + set(DOXYGEN_PREDEFINED + _azure_NON_FINAL_FOR_TESTS=final + _azure_VIRTUAL_FOR_TESTS= + ) # Skip generating docs for json, test, samples, and private files. set(DOXYGEN_EXCLUDE_PATTERNS json.hpp diff --git a/sdk/core/azure-core-amqp/CMakeLists.txt b/sdk/core/azure-core-amqp/CMakeLists.txt index 96c4ccce6..a1eb2a0e0 100644 --- a/sdk/core/azure-core-amqp/CMakeLists.txt +++ b/sdk/core/azure-core-amqp/CMakeLists.txt @@ -193,6 +193,7 @@ az_rtti_setup( if(BUILD_TESTING) # define a symbol that enables some test hooks in code add_compile_definitions(TESTING_BUILD) + add_compile_definitions(_azure_TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt index 85aa4f5f3..e78497ea8 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt +++ b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt @@ -97,6 +97,7 @@ endif() if(BUILD_AZURE_CORE_TRACING_OPENTELEMETRY AND BUILD_TESTING) # define a symbol that enables some test hooks in code add_compile_definitions(TESTING_BUILD) + add_compile_definitions(_azure_TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index 24fff3915..850519be8 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -96,6 +96,7 @@ set( inc/azure/core/internal/json/json_optional.hpp inc/azure/core/internal/json/json_serializable.hpp inc/azure/core/internal/strings.hpp + inc/azure/core/internal/test_hooks.hpp inc/azure/core/internal/tracing/service_tracing.hpp inc/azure/core/internal/tracing/tracing_impl.hpp inc/azure/core/internal/unique_handle.hpp @@ -205,6 +206,7 @@ az_rtti_setup( if(BUILD_TESTING) # define a symbol that enables some test hooks in code add_compile_definitions(TESTING_BUILD) + add_compile_definitions(_azure_TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES) include(AddGoogleTest) diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index f9e4ec1ab..447319b1a 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -16,6 +16,7 @@ #include "azure/core/http/transport.hpp" #include "azure/core/internal/http/http_sanitizer.hpp" #include "azure/core/internal/http/user_agent.hpp" +#include "azure/core/internal/test_hooks.hpp" #include "azure/core/uuid.hpp" #include @@ -30,6 +31,14 @@ #include #include +#if defined(_azure_TESTING_BUILD) +// Define the class used from tests to validate retry policy +namespace Azure { namespace Core { namespace Test { + class RetryPolicyTest; + class RetryLogic; +}}} // namespace Azure::Core::Test +#endif + /** * A function that should be implemented and linked to the end-user application in order to override * an HTTP transport implementation provided by Azure SDK with custom implementation. @@ -363,11 +372,13 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { /** * @brief HTTP retry policy. */ - class RetryPolicy -#if !defined(TESTING_BUILD) - final + class RetryPolicy _azure_NON_FINAL_FOR_TESTS : public HttpPolicy { + +#if defined(_azure_TESTING_BUILD) + friend class Azure::Core::Test::RetryPolicyTest; + friend class Azure::Core::Test::RetryLogic; #endif - : public HttpPolicy { + private: RetryOptions m_retryOptions; @@ -402,14 +413,14 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { */ static int32_t GetRetryCount(Context const& context); - protected: - virtual bool ShouldRetryOnTransportFailure( + private: + _azure_VIRTUAL_FOR_TESTS bool ShouldRetryOnTransportFailure( RetryOptions const& retryOptions, int32_t attempt, std::chrono::milliseconds& retryAfter, double jitterFactor = -1) const; - virtual bool ShouldRetryOnResponse( + _azure_VIRTUAL_FOR_TESTS bool ShouldRetryOnResponse( RawResponse const& response, RetryOptions const& retryOptions, int32_t attempt, diff --git a/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp b/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp new file mode 100644 index 000000000..c2f555f51 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/internal/test_hooks.hpp @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * @file + * @brief This file is used to define internal-only macros that are used to control the behavior of + * the Azure SDK when running tests, to allow mocking within tests. + * The macros in this file should NOT be used by anyone outside this repo. + */ + +#pragma once + +// When testing is enabled, we want to make sure that certain classes are not final, so that we can +// mock it. +#if defined(_azure_TESTING_BUILD) + +/** + * @brief If we are testing, we want to make sure that classes are not final, by default. + */ +#if !defined(_azure_NON_FINAL_FOR_TESTS) +#define _azure_NON_FINAL_FOR_TESTS +#endif + +/** + * @brief If we are testing, we want to make sure methods can be made virtual, for mocking. + */ +#if !defined(_azure_VIRTUAL_FOR_TESTS) +#define _azure_VIRTUAL_FOR_TESTS virtual +#endif + +#else + +/** + * @brief If we are not testing, we want to make sure that classes are final, by default. + */ +#if !defined(_azure_NON_FINAL_FOR_TESTS) +#define _azure_NON_FINAL_FOR_TESTS final +#endif + +/** + * @brief If we are not testing, we don't need to make methods virtual for mocking. + */ +#if !defined(_azure_VIRTUAL_FOR_TESTS) +#define _azure_VIRTUAL_FOR_TESTS +#endif + +#endif diff --git a/sdk/core/azure-core/test/ut/retry_policy_test.cpp b/sdk/core/azure-core/test/ut/retry_policy_test.cpp index b482418ce..c18870349 100644 --- a/sdk/core/azure-core/test/ut/retry_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/retry_policy_test.cpp @@ -13,823 +13,825 @@ using namespace Azure::Core::Http; using namespace Azure::Core::Http::Policies; using namespace Azure::Core::Http::Policies::_internal; -namespace { -class TestTransportPolicy final : public HttpPolicy { -private: - std::function()> m_send; - -public: - TestTransportPolicy(std::function()> send) : m_send(send) {} - - std::unique_ptr Send( - Request&, - NextHttpPolicy, - Azure::Core::Context const&) const override - { - return m_send(); - } - - std::unique_ptr Clone() const override - { - return std::make_unique(*this); - } -}; - -class RetryPolicyTest final : public RetryPolicy { -private: - std::function - m_shouldRetryOnTransportFailure; - - std::function< - bool(RawResponse const&, RetryOptions const&, int32_t, std::chrono::milliseconds&, double)> - m_shouldRetryOnResponse; - -public: - bool BaseShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const - { - return RetryPolicy::ShouldRetryOnTransportFailure( - retryOptions, attempt, retryAfter, jitterFactor); - } - - bool BaseShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const - { - return RetryPolicy::ShouldRetryOnResponse( - response, retryOptions, attempt, retryAfter, jitterFactor); - } - - RetryPolicyTest( - RetryOptions const& retryOptions, - decltype(m_shouldRetryOnTransportFailure) shouldRetryOnTransportFailure, - decltype(m_shouldRetryOnResponse) shouldRetryOnResponse) - : RetryPolicy(retryOptions), - m_shouldRetryOnTransportFailure( - shouldRetryOnTransportFailure != nullptr // - ? shouldRetryOnTransportFailure - : static_cast( // - [this](auto options, auto attempt, auto retryAfter, auto jitter) { - retryAfter = std::chrono::milliseconds(0); - auto ignore = decltype(retryAfter)(); - return this->BaseShouldRetryOnTransportFailure( - options, attempt, ignore, jitter); - })), - m_shouldRetryOnResponse( - shouldRetryOnResponse != nullptr // - ? shouldRetryOnResponse - : static_cast( // - [this]( - RawResponse const& response, - auto options, - auto attempt, - auto retryAfter, - auto jitter) { - retryAfter = std::chrono::milliseconds(0); - auto ignore = decltype(retryAfter)(); - return this->BaseShouldRetryOnResponse( - response, options, attempt, ignore, jitter); - })) - { - } - - std::unique_ptr Clone() const override - { - return std::make_unique(*this); - } - -protected: - bool ShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const override - { - return m_shouldRetryOnTransportFailure(retryOptions, attempt, retryAfter, jitterFactor); - } - - bool ShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) const override - { - return m_shouldRetryOnResponse(response, retryOptions, attempt, retryAfter, jitterFactor); - } -}; -} // namespace - -TEST(RetryPolicy, ShouldRetryOnResponse) -{ - using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - - RawResponse const* responsePtrSent = nullptr; - - RawResponse const* responsePtrReceived = nullptr; - RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; - int32_t attemptReceived = -1234; - double jitterReceived = -5678; - - int onTransportFailureInvoked = 0; - int onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](RawResponse const& response, auto options, auto attempt, auto, auto jitter) { - ++onResponseInvoked; - responsePtrReceived = &response; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - })); - - policies.emplace_back(std::make_unique([&]() { - auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - - responsePtrSent = response.get(); - - return response; - })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); - } - - EXPECT_EQ(onTransportFailureInvoked, 0); - EXPECT_EQ(onResponseInvoked, 1); - - EXPECT_NE(responsePtrSent, nullptr); - EXPECT_EQ(responsePtrSent, responsePtrReceived); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 1); - EXPECT_EQ(jitterReceived, -1); - - // 3 attempts - responsePtrSent = nullptr; - - responsePtrReceived = nullptr; - retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; - attemptReceived = -1234; - jitterReceived = -5678; - - onTransportFailureInvoked = 0; - onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](RawResponse const& response, auto options, auto attempt, auto retryAfter, auto jitter) { - ++onResponseInvoked; - responsePtrReceived = &response; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - retryAfter = 1ms; - return onResponseInvoked < 3; - })); - - policies.emplace_back(std::make_unique([&]() { - auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); - - responsePtrSent = response.get(); - - return response; - })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); - } - - EXPECT_EQ(onTransportFailureInvoked, 0); - EXPECT_EQ(onResponseInvoked, 3); - - EXPECT_NE(responsePtrSent, nullptr); - EXPECT_EQ(responsePtrSent, responsePtrReceived); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 3); - EXPECT_EQ(jitterReceived, -1); -} - -TEST(RetryPolicy, ShouldRetryOnTransportFailure) -{ - using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - - RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; - int32_t attemptReceived = -1234; - double jitterReceived = -5678; - - int onTransportFailureInvoked = 0; - int onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - }, - [&](auto, auto options, auto attempt, auto, auto jitter) { - ++onResponseInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return false; - })); - - policies.emplace_back(std::make_unique( - []() -> std::unique_ptr { throw TransportException("Test"); })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); - } - - EXPECT_EQ(onTransportFailureInvoked, 1); - EXPECT_EQ(onResponseInvoked, 0); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 1); - EXPECT_EQ(jitterReceived, -1); - - // 3 attempts - retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; - attemptReceived = -1234; - jitterReceived = -5678; - - onTransportFailureInvoked = 0; - onResponseInvoked = 0; - - { - std::vector> policies; - policies.emplace_back(std::make_unique( - retryOptions, - [&](auto options, auto attempt, auto, auto jitter) { - ++onTransportFailureInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - return onTransportFailureInvoked < 3; - }, - [&](auto, auto options, auto attempt, auto retryAfter, auto jitter) { - ++onResponseInvoked; - retryOptionsReceived = options; - attemptReceived = attempt; - jitterReceived = jitter; - - retryAfter = 1ms; - return false; - })); - - policies.emplace_back(std::make_unique( - []() -> std::unique_ptr { throw TransportException("Test"); })); - - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); - - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); - } - - EXPECT_EQ(onTransportFailureInvoked, 3); - EXPECT_EQ(onResponseInvoked, 0); - - EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); - EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); - EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); - EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); - - EXPECT_EQ(attemptReceived, 3); - EXPECT_EQ(jitterReceived, -1); -} - -namespace { -class RetryLogic final : private RetryPolicy { - RetryLogic() : RetryPolicy(RetryOptions()) {} - ~RetryLogic() {} - - static RetryLogic const g_retryPolicy; - -public: - static bool TestShouldRetryOnTransportFailure( - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) - { - return g_retryPolicy.ShouldRetryOnTransportFailure( - retryOptions, attempt, retryAfter, jitterFactor); - } - - static bool TestShouldRetryOnResponse( - RawResponse const& response, - RetryOptions const& retryOptions, - int32_t attempt, - std::chrono::milliseconds& retryAfter, - double jitterFactor) - { - return g_retryPolicy.ShouldRetryOnResponse( - response, retryOptions, attempt, retryAfter, jitterFactor); - } -}; - -RetryLogic const RetryLogic::g_retryPolicy; -} // namespace - -TEST(RetryPolicy, Exponential) -{ - using namespace std::chrono_literals; - - RetryOptions const options{3, 1s, 2min, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 4s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } -} - -TEST(RetryPolicy, LessThan2Retries) -{ - using namespace std::chrono_literals; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({1, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({0, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({-1, 1s, 2min, {}}, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, false); - } -} - -TEST(RetryPolicy, NotExceedingMaxRetryDelay) -{ - using namespace std::chrono_literals; - - RetryOptions const options{7, 1s, 20s, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 4s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 8s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 5, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 16s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 6, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 20s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 7, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 20s); - } -} - -TEST(RetryPolicy, NotExceedingInt32Max) -{ - using namespace std::chrono_literals; - - RetryOptions const options{35, 1s, 9999999999999s, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 31, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1073741824s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 32, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 33, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 34, retryAfter, 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2147483647s); - } -} - -TEST(RetryPolicy, Jitter) -{ - using namespace std::chrono_literals; - - RetryOptions const options{3, 10s, 20min, {}}; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 8s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 13s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 16s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 26s); - } -} - -TEST(RetryPolicy, JitterExtremes) -{ - using namespace std::chrono_literals; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 1ms, 2min, {}}, 1, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 0ms); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 2ms, 2min, {}}, 1, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1ms); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 2, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 21s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry - = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 3, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 21s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnTransportFailure( - {35, 1s, 9999999999999s, {}}, 33, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 2791728741100ms); - } -} - -TEST(RetryPolicy, HttpStatusCode) -{ - using namespace std::chrono_literals; - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), - {3, 3210s, 3h, {HttpStatusCode::RequestTimeout}}, - 1, - retryAfter, - 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 3210s); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), - {3, 654s, 3h, {HttpStatusCode::Ok}}, - 1, - retryAfter, - 1.0); - - EXPECT_EQ(shouldRetry, false); - } - - { - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - RawResponse(1, 1, HttpStatusCode::Ok, ""), - {3, 987s, 3h, {HttpStatusCode::Ok}}, - 1, - retryAfter, - 1.0); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 987s); - } -} - -TEST(RetryPolicy, RetryAfterMs) -{ - using namespace std::chrono_literals; - - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("rEtRy-aFtEr-mS", "1234"); - - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.3); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 1234ms); - } - - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("X-mS-ReTrY-aFtEr-MS", "5678"); - - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 0.8); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 5678ms); - } -} - -TEST(RetryPolicy, RetryAfter) -{ - using namespace std::chrono_literals; - - { - RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); - response.SetHeader("rEtRy-aFtEr", "90"); - - std::chrono::milliseconds retryAfter{}; - bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( - response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.1); - - EXPECT_EQ(shouldRetry, true); - EXPECT_EQ(retryAfter, 90s); - } -} - -TEST(RetryPolicy, LogMessages) -{ - using Azure::Core::Diagnostics::Logger; - - struct Log - { - struct Entry +namespace Azure { namespace Core { namespace Test { + class TestTransportPolicy final : public HttpPolicy { + private: + std::function()> m_send; + + public: + TestTransportPolicy(std::function()> send) : m_send(send) {} + + std::unique_ptr Send( + Request&, + NextHttpPolicy, + Azure::Core::Context const&) const override { - Logger::Level Level; - std::string Message; - }; - - std::vector Entries; - - Log() - { - Logger::SetLevel(Logger::Level::Informational); - Logger::SetListener([&](auto lvl, auto msg) { Entries.emplace_back(Entry{lvl, msg}); }); + return m_send(); } - ~Log() + std::unique_ptr Clone() const override { - Logger::SetListener(nullptr); - Logger::SetLevel(Logger::Level::Warning); + return std::make_unique(*this); + } + }; + + class RetryPolicyTest final : public RetryPolicy { + private: + std::function + m_shouldRetryOnTransportFailure; + + std::function< + bool(RawResponse const&, RetryOptions const&, int32_t, std::chrono::milliseconds&, double)> + m_shouldRetryOnResponse; + + public: + bool BaseShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const + { + return RetryPolicy::ShouldRetryOnTransportFailure( + retryOptions, attempt, retryAfter, jitterFactor); } - } log; + bool BaseShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const + { + return RetryPolicy::ShouldRetryOnResponse( + response, retryOptions, attempt, retryAfter, jitterFactor); + } + RetryPolicyTest( + RetryOptions const& retryOptions, + decltype(m_shouldRetryOnTransportFailure) shouldRetryOnTransportFailure, + decltype(m_shouldRetryOnResponse) shouldRetryOnResponse) + : RetryPolicy(retryOptions), + m_shouldRetryOnTransportFailure( + shouldRetryOnTransportFailure != nullptr // + ? shouldRetryOnTransportFailure + : static_cast( // + [this](auto options, auto attempt, auto retryAfter, auto jitter) { + retryAfter = std::chrono::milliseconds(0); + auto ignore = decltype(retryAfter)(); + return this->BaseShouldRetryOnTransportFailure( + options, attempt, ignore, jitter); + })), + m_shouldRetryOnResponse( + shouldRetryOnResponse != nullptr // + ? shouldRetryOnResponse + : static_cast( // + [this]( + RawResponse const& response, + auto options, + auto attempt, + auto retryAfter, + auto jitter) { + retryAfter = std::chrono::milliseconds(0); + auto ignore = decltype(retryAfter)(); + return this->BaseShouldRetryOnResponse( + response, options, attempt, ignore, jitter); + })) + { + } + + std::unique_ptr Clone() const override + { + return std::make_unique(*this); + } + + protected: + bool ShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const override + { + return m_shouldRetryOnTransportFailure(retryOptions, attempt, retryAfter, jitterFactor); + } + + bool ShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) const override + { + return m_shouldRetryOnResponse(response, retryOptions, attempt, retryAfter, jitterFactor); + } + }; + + TEST(RetryPolicy, ShouldRetryOnResponse) { using namespace std::chrono_literals; - RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::InternalServerError}}; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - auto requestNumber = 0; + RawResponse const* responsePtrSent = nullptr; - std::vector> policies; - policies.emplace_back(std::make_unique(retryOptions, nullptr, nullptr)); - policies.emplace_back(std::make_unique([&]() { - ++requestNumber; + RawResponse const* responsePtrReceived = nullptr; + RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; + int32_t attemptReceived = -1234; + double jitterReceived = -5678; - if (requestNumber == 1) - { - throw TransportException("Cable Unplugged"); - } + int onTransportFailureInvoked = 0; + int onResponseInvoked = 0; - return std::make_unique( - 1, - 1, - requestNumber == 2 ? HttpStatusCode::InternalServerError - : HttpStatusCode::ServiceUnavailable, - "Test"); - })); + { + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + return false; + }, + [&](RawResponse const& response, auto options, auto attempt, auto, auto jitter) { + ++onResponseInvoked; + responsePtrReceived = &response; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); - pipeline.Send(request, Azure::Core::Context()); + return false; + })); + + policies.emplace_back(std::make_unique([&]() { + auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); + + responsePtrSent = response.get(); + + return response; + })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); + } + + EXPECT_EQ(onTransportFailureInvoked, 0); + EXPECT_EQ(onResponseInvoked, 1); + + EXPECT_NE(responsePtrSent, nullptr); + EXPECT_EQ(responsePtrSent, responsePtrReceived); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 1); + EXPECT_EQ(jitterReceived, -1); + + // 3 attempts + responsePtrSent = nullptr; + + responsePtrReceived = nullptr; + retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; + attemptReceived = -1234; + jitterReceived = -5678; + + onTransportFailureInvoked = 0; + onResponseInvoked = 0; + + { + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; + + return false; + }, + [&](RawResponse const& response, + auto options, + auto attempt, + auto retryAfter, + auto jitter) { + ++onResponseInvoked; + responsePtrReceived = &response; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; + + retryAfter = 1ms; + return onResponseInvoked < 3; + })); + + policies.emplace_back(std::make_unique([&]() { + auto response = std::make_unique(1, 1, HttpStatusCode::Ok, "Test"); + + responsePtrSent = response.get(); + + return response; + })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); + } + + EXPECT_EQ(onTransportFailureInvoked, 0); + EXPECT_EQ(onResponseInvoked, 3); + + EXPECT_NE(responsePtrSent, nullptr); + EXPECT_EQ(responsePtrSent, responsePtrReceived); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 3); + EXPECT_EQ(jitterReceived, -1); } - EXPECT_EQ(log.Entries.size(), 5); + TEST(RetryPolicy, ShouldRetryOnTransportFailure) + { + using namespace std::chrono_literals; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::Ok}}; - EXPECT_EQ(log.Entries[0].Level, Logger::Level::Warning); - EXPECT_EQ(log.Entries[0].Message, "HTTP Transport error: Cable Unplugged"); + RetryOptions retryOptionsReceived{0, 0ms, 0ms, {}}; + int32_t attemptReceived = -1234; + double jitterReceived = -5678; - EXPECT_EQ(log.Entries[1].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[1].Message, "HTTP Retry attempt #1 will be made in 0ms."); + int onTransportFailureInvoked = 0; + int onResponseInvoked = 0; - EXPECT_EQ(log.Entries[2].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[2].Message, "HTTP status code 500 will be retried."); + { + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - EXPECT_EQ(log.Entries[3].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[3].Message, "HTTP Retry attempt #2 will be made in 0ms."); + return false; + }, + [&](auto, auto options, auto attempt, auto, auto jitter) { + ++onResponseInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; - EXPECT_EQ(log.Entries[4].Level, Logger::Level::Informational); - EXPECT_EQ(log.Entries[4].Message, "HTTP status code 503 won't be retried."); -} + return false; + })); + + policies.emplace_back(std::make_unique( + []() -> std::unique_ptr { throw TransportException("Test"); })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); + } + + EXPECT_EQ(onTransportFailureInvoked, 1); + EXPECT_EQ(onResponseInvoked, 0); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 1); + EXPECT_EQ(jitterReceived, -1); + + // 3 attempts + retryOptionsReceived = RetryOptions{0, 0ms, 0ms, {}}; + attemptReceived = -1234; + jitterReceived = -5678; + + onTransportFailureInvoked = 0; + onResponseInvoked = 0; + + { + std::vector> policies; + policies.emplace_back(std::make_unique( + retryOptions, + [&](auto options, auto attempt, auto, auto jitter) { + ++onTransportFailureInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; + + return onTransportFailureInvoked < 3; + }, + [&](auto, auto options, auto attempt, auto retryAfter, auto jitter) { + ++onResponseInvoked; + retryOptionsReceived = options; + attemptReceived = attempt; + jitterReceived = jitter; + + retryAfter = 1ms; + return false; + })); + + policies.emplace_back(std::make_unique( + []() -> std::unique_ptr { throw TransportException("Test"); })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + EXPECT_THROW(pipeline.Send(request, Azure::Core::Context()), TransportException); + } + + EXPECT_EQ(onTransportFailureInvoked, 3); + EXPECT_EQ(onResponseInvoked, 0); + + EXPECT_EQ(retryOptionsReceived.MaxRetries, retryOptions.MaxRetries); + EXPECT_EQ(retryOptionsReceived.RetryDelay, retryOptions.RetryDelay); + EXPECT_EQ(retryOptionsReceived.MaxRetryDelay, retryOptions.MaxRetryDelay); + EXPECT_EQ(retryOptionsReceived.StatusCodes, retryOptions.StatusCodes); + + EXPECT_EQ(attemptReceived, 3); + EXPECT_EQ(jitterReceived, -1); + } + + class RetryLogic final : private RetryPolicy { + RetryLogic() : RetryPolicy(RetryOptions()) {} + ~RetryLogic() {} + + static RetryLogic const g_retryPolicy; + + public: + static bool TestShouldRetryOnTransportFailure( + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) + { + return g_retryPolicy.ShouldRetryOnTransportFailure( + retryOptions, attempt, retryAfter, jitterFactor); + } + + static bool TestShouldRetryOnResponse( + RawResponse const& response, + RetryOptions const& retryOptions, + int32_t attempt, + std::chrono::milliseconds& retryAfter, + double jitterFactor) + { + return g_retryPolicy.ShouldRetryOnResponse( + response, retryOptions, attempt, retryAfter, jitterFactor); + } + }; + + RetryLogic const RetryLogic::g_retryPolicy; + + TEST(RetryPolicy, Exponential) + { + using namespace std::chrono_literals; + + RetryOptions const options{3, 1s, 2min, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 4s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } + } + + TEST(RetryPolicy, LessThan2Retries) + { + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({1, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({0, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({-1, 1s, 2min, {}}, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, false); + } + } + + TEST(RetryPolicy, NotExceedingMaxRetryDelay) + { + using namespace std::chrono_literals; + + RetryOptions const options{7, 1s, 20s, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 3, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 4s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 4, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 8s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 5, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 16s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 6, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 20s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 7, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 20s); + } + } + + TEST(RetryPolicy, NotExceedingInt32Max) + { + using namespace std::chrono_literals; + + RetryOptions const options{35, 1s, 9999999999999s, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 31, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1073741824s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 32, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 33, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 34, retryAfter, 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2147483647s); + } + } + + TEST(RetryPolicy, Jitter) + { + using namespace std::chrono_literals; + + RetryOptions const options{3, 10s, 20min, {}}; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 8s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 1, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 13s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 16s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure(options, 2, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 26s); + } + } + + TEST(RetryPolicy, JitterExtremes) + { + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 1ms, 2min, {}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 0ms); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 2ms, 2min, {}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1ms); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 2, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 21s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry + = RetryLogic::TestShouldRetryOnTransportFailure({3, 10s, 21s, {}}, 3, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 21s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnTransportFailure( + {35, 1s, 9999999999999s, {}}, 33, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 2791728741100ms); + } + } + + TEST(RetryPolicy, HttpStatusCode) + { + using namespace std::chrono_literals; + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), + {3, 3210s, 3h, {HttpStatusCode::RequestTimeout}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 3210s); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::RequestTimeout, ""), + {3, 654s, 3h, {HttpStatusCode::Ok}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, false); + } + + { + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + RawResponse(1, 1, HttpStatusCode::Ok, ""), + {3, 987s, 3h, {HttpStatusCode::Ok}}, + 1, + retryAfter, + 1.0); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 987s); + } + } + + TEST(RetryPolicy, RetryAfterMs) + { + using namespace std::chrono_literals; + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("rEtRy-aFtEr-mS", "1234"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.3); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 1234ms); + } + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("X-mS-ReTrY-aFtEr-MS", "5678"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 0.8); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 5678ms); + } + } + + TEST(RetryPolicy, RetryAfter) + { + using namespace std::chrono_literals; + + { + RawResponse response(1, 1, HttpStatusCode::RequestTimeout, ""); + response.SetHeader("rEtRy-aFtEr", "90"); + + std::chrono::milliseconds retryAfter{}; + bool const shouldRetry = RetryLogic::TestShouldRetryOnResponse( + response, {3, 1s, 2min, {HttpStatusCode::RequestTimeout}}, 1, retryAfter, 1.1); + + EXPECT_EQ(shouldRetry, true); + EXPECT_EQ(retryAfter, 90s); + } + } + + TEST(RetryPolicy, LogMessages) + { + using Azure::Core::Diagnostics::Logger; + + struct Log + { + struct Entry + { + Logger::Level Level; + std::string Message; + }; + + std::vector Entries; + + Log() + { + Logger::SetLevel(Logger::Level::Informational); + Logger::SetListener([&](auto lvl, auto msg) { Entries.emplace_back(Entry{lvl, msg}); }); + } + + ~Log() + { + Logger::SetListener(nullptr); + Logger::SetLevel(Logger::Level::Warning); + } + + } log; + + { + using namespace std::chrono_literals; + RetryOptions const retryOptions{5, 10s, 5min, {HttpStatusCode::InternalServerError}}; + + auto requestNumber = 0; + + std::vector> policies; + policies.emplace_back(std::make_unique(retryOptions, nullptr, nullptr)); + policies.emplace_back(std::make_unique([&]() { + ++requestNumber; + + if (requestNumber == 1) + { + throw TransportException("Cable Unplugged"); + } + + return std::make_unique( + 1, + 1, + requestNumber == 2 ? HttpStatusCode::InternalServerError + : HttpStatusCode::ServiceUnavailable, + "Test"); + })); + + Azure::Core::Http::_internal::HttpPipeline pipeline(policies); + + Request request(HttpMethod::Get, Azure::Core::Url("https://www.microsoft.com")); + pipeline.Send(request, Azure::Core::Context()); + } + + EXPECT_EQ(log.Entries.size(), 5); + + EXPECT_EQ(log.Entries[0].Level, Logger::Level::Warning); + EXPECT_EQ(log.Entries[0].Message, "HTTP Transport error: Cable Unplugged"); + + EXPECT_EQ(log.Entries[1].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[1].Message, "HTTP Retry attempt #1 will be made in 0ms."); + + EXPECT_EQ(log.Entries[2].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[2].Message, "HTTP status code 500 will be retried."); + + EXPECT_EQ(log.Entries[3].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[3].Message, "HTTP Retry attempt #2 will be made in 0ms."); + + EXPECT_EQ(log.Entries[4].Level, Logger::Level::Informational); + EXPECT_EQ(log.Entries[4].Message, "HTTP status code 503 won't be retried."); + } +}}} // namespace Azure::Core::Test diff --git a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp index 5ff6f1ece..be24fd050 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/azure_cli_credential.hpp @@ -9,11 +9,11 @@ #pragma once #include "azure/identity/detail/token_cache.hpp" -#include "azure/identity/dll_import_export.hpp" #include #include #include +#include #include #include @@ -56,11 +56,7 @@ namespace Azure { namespace Identity { * @brief Enables authentication to Microsoft Entra ID using Azure CLI to obtain an access * token. */ - class AzureCliCredential -#if !defined(_azure_TESTING_BUILD) - final -#endif - : public Core::Credentials::TokenCredential { + class AzureCliCredential _azure_NON_FINAL_FOR_TESTS : public Core::Credentials::TokenCredential { #if defined(_azure_TESTING_BUILD) friend class Azure::Identity::Test::AzureCliTestCredential; diff --git a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp index 6ca49b82b..97994364c 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/detail/token_cache.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include @@ -19,26 +20,30 @@ #include #include +#if defined(_azure_TESTING_BUILD) +// Define the class used from tests to validate retry policy +namespace Azure { namespace Identity { namespace Test { + class TestableTokenCache; +}}} // namespace Azure::Identity::Test +#endif + namespace Azure { namespace Identity { namespace _detail { /** * @brief Access token cache. * */ - class TokenCache -#if !defined(TESTING_BUILD) - final + class TokenCache _azure_NON_FINAL_FOR_TESTS { + +#if defined(_azure_TESTING_BUILD) + friend class Azure::Identity::Test::TestableTokenCache; #endif - { -#if !defined(TESTING_BUILD) + private: -#else - protected: -#endif // A test hook that gets invoked before cache write lock gets acquired. - virtual void OnBeforeCacheWriteLock() const {}; + _azure_VIRTUAL_FOR_TESTS void OnBeforeCacheWriteLock() const {}; // A test hook that gets invoked before item write lock gets acquired. - virtual void OnBeforeItemWriteLock() const {}; + _azure_VIRTUAL_FOR_TESTS void OnBeforeItemWriteLock() const {}; struct CacheKey { @@ -63,7 +68,6 @@ namespace Azure { namespace Identity { namespace _detail { mutable std::map, CacheKeyComparator> m_cache; mutable std::shared_timed_mutex m_cacheMutex; - private: TokenCache(TokenCache const&) = delete; TokenCache& operator=(TokenCache const&) = delete; diff --git a/sdk/identity/azure-identity/inc/azure/identity/dll_import_export.hpp b/sdk/identity/azure-identity/inc/azure/identity/dll_import_export.hpp index 11f02aeff..6b01515fb 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/dll_import_export.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/dll_import_export.hpp @@ -38,16 +38,6 @@ #undef AZ_IDENTITY_BUILT_AS_DLL -#if defined(_azure_TESTING_BUILD) -#if !defined(_azure_VIRTUAL_FOR_TESTS) -#define _azure_VIRTUAL_FOR_TESTS virtual -#endif -#else -#if !defined(_azure_VIRTUAL_FOR_TESTS) -#define _azure_VIRTUAL_FOR_TESTS -#endif -#endif - /** * @brief Azure SDK abstractions. * diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp index 550dd7150..409422081 100644 --- a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -12,142 +12,109 @@ using Azure::DateTime; using Azure::Core::Credentials::AccessToken; using Azure::Identity::_detail::TokenCache; -namespace { -class TestableTokenCache final : public TokenCache { -public: - using TokenCache::CacheValue; - using TokenCache::m_cache; - using TokenCache::m_cacheMutex; +namespace Azure { namespace Identity { namespace Test { + class TestableTokenCache final : public TokenCache { + public: + using TokenCache::CacheValue; + using TokenCache::m_cache; + using TokenCache::m_cacheMutex; - mutable std::function m_onBeforeCacheWriteLock; - mutable std::function m_onBeforeItemWriteLock; + mutable std::function m_onBeforeCacheWriteLock; + mutable std::function m_onBeforeItemWriteLock; - void OnBeforeCacheWriteLock() const override - { - if (m_onBeforeCacheWriteLock != nullptr) + void OnBeforeCacheWriteLock() const override { - m_onBeforeCacheWriteLock(); + if (m_onBeforeCacheWriteLock != nullptr) + { + m_onBeforeCacheWriteLock(); + } } - } - void OnBeforeItemWriteLock() const override - { - if (m_onBeforeItemWriteLock != nullptr) + void OnBeforeItemWriteLock() const override { - m_onBeforeItemWriteLock(); + if (m_onBeforeItemWriteLock != nullptr) + { + m_onBeforeItemWriteLock(); + } } - } -}; -} // namespace - -using namespace std::chrono_literals; - -TEST(TokenCache, GetReuseRefresh) -{ - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; - - { - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); - EXPECT_EQ(token1.Token, token2.Token); - } - - { - tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday; - - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T3"; - result.ExpiresOn = Tomorrow + 1min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); - EXPECT_EQ(token.Token, "T3"); - } -} - -TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) -{ - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - tokenCache.m_onBeforeCacheWriteLock = [&]() { - tokenCache.m_onBeforeCacheWriteLock = nullptr; - static_cast(tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); }; - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " - "acquiring cache write lock"); - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "T1"); -} - -TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) -{ - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; + using namespace std::chrono_literals; + TEST(TokenCache, GetReuseRefresh) { TestableTokenCache tokenCache; EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - tokenCache.m_onBeforeItemWriteLock = [&]() { - tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache[{"A", {}}]; - item->AccessToken.Token = "T1"; - item->AccessToken.ExpiresOn = Tomorrow; + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + { + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); + EXPECT_EQ(token1.Token, token2.Token); + } + + { + tokenCache.m_cache[{"A", {}}]->AccessToken.ExpiresOn = Yesterday; + + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); + EXPECT_EQ(token.Token, "T3"); + } + } + + TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) + { + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + tokenCache.m_onBeforeCacheWriteLock = [&]() { + tokenCache.m_onBeforeCacheWriteLock = nullptr; + static_cast(tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); }; auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " - "acquiring item write lock"); + "acquiring cache write lock"); AccessToken result; result.Token = "T2"; result.ExpiresOn = Tomorrow + 1min; @@ -160,511 +127,547 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) EXPECT_EQ(token.Token, "T1"); } - // Same as above, but the token that was inserted is already expired. + TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) { - TestableTokenCache tokenCache; - - tokenCache.m_onBeforeItemWriteLock = [&]() { - tokenCache.m_onBeforeItemWriteLock = nullptr; - auto const item = tokenCache.m_cache[{"A", {}}]; - item->AccessToken.Token = "T3"; - item->AccessToken.ExpiresOn = Yesterday; - }; - - auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T4"; - result.ExpiresOn = Tomorrow + 3min; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); - EXPECT_EQ(token.Token, "T4"); - } -} - -TEST(TokenCache, ExpiredCleanup) -{ - // Expected cleanup points are when cache size is in the Fibonacci sequence: - // 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ... - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - auto const Yesterday = Tomorrow - 48h; - - TestableTokenCache tokenCache; - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - for (auto i = 1; i <= 35; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - // Simply: we added 34+1 token, none of them has expired. None are expected to be cleaned up. - EXPECT_EQ(tokenCache.m_cache.size(), 35UL); - - // Let's expire 3 of them, with numbers from 1 to 3. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; - } - - // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. - for (auto i = 36; i <= 55; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - - // Count is at 55. Tokens from 1 to 3 are still in cache even though they are expired. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - // One more addition to the cache and cleanup for the expired ones will get triggered. - static_cast(tokenCache.GetToken("56", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - })); - - // We were at 55 before we added 1 more, and now we're at 53. 3 were deleted, 1 was added. - EXPECT_EQ(tokenCache.m_cache.size(), 53UL); - - // Items from 1 to 3 should no longer be in the cache. - for (auto i = 1; i <= 3; ++i) - { - auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - // Let's expire items from 21 all the way up to 56. - for (auto i = 21; i <= 56; ++i) - { - auto const n = std::to_string(i); - tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; - } - - // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to - // 55 items (with numbers from 2 to 56, and number 1 missing). - for (auto i = 2; i <= 3; ++i) - { - auto const n = std::to_string(i); - static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow; - return result; - })); - } - - // Cache is now at 55 again (items from 2 to 56). Adding 1 more will trigger cleanup. - EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - - // Now let's lock some of the items for reading, and some for writing. Cleanup should not block on - // token release, but will simply move on, without doing anything to the ones that were locked. - // Out of 4 locked, two are expired, so they should get cleared under normal circumstances, but - // this time they will remain in the cache. - std::shared_lock readLockForUnexpired( - tokenCache.m_cache[{"2", {}}]->ElementMutex); - - std::shared_lock readLockForExpired( - tokenCache.m_cache[{"54", {}}]->ElementMutex); - - std::unique_lock writeLockForUnexpired( - tokenCache.m_cache[{"3", {}}]->ElementMutex); - - std::unique_lock writeLockForExpired( - tokenCache.m_cache[{"55", {}}]->ElementMutex); - - // Count is at 55. Inserting the 56th element, and it will trigger cleanup. - static_cast(tokenCache.GetToken("1", {}, 2min, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow; - return result; - })); - - // These should be 20 unexpired items + two that are expired but were locked, so 22 total. - EXPECT_EQ(tokenCache.m_cache.size(), 22UL); - - for (auto i = 1; i <= 20; ++i) - { - auto const n = std::to_string(i); - EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } - - EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end()); - - EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end()); - - for (auto i = 21; i <= 53; ++i) - { - auto const n = std::to_string(i); - EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); - } -} - -TEST(TokenCache, MinimumExpiration) -{ - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() { - AccessToken result; - result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); - EXPECT_EQ(token2.Token, "T2"); -} - -TEST(TokenCache, MultithreadedAccess) -{ - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { - AccessToken result; - result.Token = "T1"; - result.ExpiresOn = Tomorrow; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - - EXPECT_EQ(token1.ExpiresOn, Tomorrow); - EXPECT_EQ(token1.Token, "T1"); - - { - std::shared_lock itemReadLock( - tokenCache.m_cache[{"A", {}}]->ElementMutex); + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; { - std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); + TestableTokenCache tokenCache; - // Parallel threads read both the container and the item we're accessing, and we can - // access it in parallel as well. - auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache[{"A", {}}]; + item->AccessToken.Token = "T1"; + item->AccessToken.ExpiresOn = Tomorrow; + }; + + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE( + "getNewToken does not get invoked when the fresh value was inserted just before " + "acquiring item write lock"); AccessToken result; result.Token = "T2"; - result.ExpiresOn = Tomorrow + 1h; + result.ExpiresOn = Tomorrow + 1min; return result; }); EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); - EXPECT_EQ(token2.Token, token1.Token); + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "T1"); } - // The cache is unlocked, but one item is being read in a parallel thread, which does not - // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() { - AccessToken result; - result.Token = "T3"; - result.ExpiresOn = Tomorrow + 2h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 2UL); - - EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); - EXPECT_EQ(token3.Token, "T3"); - } - - { - std::unique_lock itemWriteLock( - tokenCache.m_cache[{"A", {}}]->ElementMutex); - - // The cache is unlocked, but one item is being written in a parallel thread, which does not - // prevent new items (with different key) from being appended to cache. - auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() { - AccessToken result; - result.Token = "T4"; - result.ExpiresOn = Tomorrow + 3h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 3UL); - - EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); - EXPECT_EQ(token3.Token, "T4"); - } -} - -using Azure::Core::Context; -using Azure::Core::Http::HttpTransport; -using Azure::Core::Http::RawResponse; -using Azure::Core::Http::Request; - -namespace { -class TestTransport final : public HttpTransport { - int m_attemptNumber = 0; - std::vector m_responseBuf; - -public: - // Returns token response with 3600 seconds expiration (1 hour), and the value of the - // client_secret parameter from the body + attempt number as token value. - std::unique_ptr Send(Request& request, Context const&) override - { - using Azure::Core::Http::HttpStatusCode; - using Azure::Core::IO::BodyStream; - using Azure::Core::IO::MemoryBodyStream; - - ++m_attemptNumber; - - std::string clientSecret; + // Same as above, but the token that was inserted is already expired. { - std::string const ClientSecretStart = "client_secret="; + TestableTokenCache tokenCache; - auto const reqBodyVec = request.GetBodyStream()->ReadToEnd(); - auto const reqBodyStr = std::string(reqBodyVec.cbegin(), reqBodyVec.cend()); + tokenCache.m_onBeforeItemWriteLock = [&]() { + tokenCache.m_onBeforeItemWriteLock = nullptr; + auto const item = tokenCache.m_cache[{"A", {}}]; + item->AccessToken.Token = "T3"; + item->AccessToken.ExpiresOn = Yesterday; + }; - auto clientSecretStartPos = reqBodyStr.find(ClientSecretStart); - if (clientSecretStartPos != std::string::npos) - { - clientSecretStartPos += ClientSecretStart.size(); - auto const clientSecretEndPos = reqBodyStr.find('&', clientSecretStartPos); + auto const token = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3min; + return result; + }); - clientSecret = (clientSecretEndPos == std::string::npos) - ? reqBodyStr.substr(clientSecretStartPos) - : reqBodyStr.substr(clientSecretStartPos, clientSecretEndPos - clientSecretStartPos); - } + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); + EXPECT_EQ(token.Token, "T4"); + } + } + + TEST(TokenCache, ExpiredCleanup) + { + // Expected cleanup points are when cache size is in the Fibonacci sequence: + // 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ... + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + TestableTokenCache tokenCache; + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + for (auto i = 1; i <= 35; ++i) + { + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); } - auto const respBodyStr = std::string("{ \"access_token\" : \"") + clientSecret - + std::to_string(m_attemptNumber) + "\", \"expires_in\" : 3600 }"; + // Simply: we added 34+1 token, none of them has expired. None are expected to be cleaned up. + EXPECT_EQ(tokenCache.m_cache.size(), 35UL); - m_responseBuf.assign(respBodyStr.cbegin(), respBodyStr.cend()); + // Let's expire 3 of them, with numbers from 1 to 3. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; + } - auto resp = std::make_unique(1, 1, HttpStatusCode::Ok, "OK"); - resp->SetBodyStream(std::make_unique(m_responseBuf)); - return resp; - } -}; -} // namespace + // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. + for (auto i = 36; i <= 55; ++i) + { + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + } -TEST(TokenCache, PerCredInstance) -{ - using Azure::Core::Credentials::TokenCredentialOptions; - using Azure::Core::Credentials::TokenRequestContext; - using Azure::Identity::ClientSecretCredential; + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - TokenRequestContext getCached; - getCached.Scopes = {"https://vault.azure.net/.default"}; - getCached.MinimumExpiration = 1s; + // Count is at 55. Tokens from 1 to 3 are still in cache even though they are expired. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } - TokenCredentialOptions credOptions; - credOptions.Transport.Transport = std::make_shared(); - - ClientSecretCredential credA("TenantId", "ClientId", "SecretA", credOptions); - ClientSecretCredential credB("TenantId", "ClientId", "SecretB", credOptions); - - { - auto const tokenA1 = credA.GetToken(getCached, {}); // Should populate - EXPECT_EQ(tokenA1.Token, "SecretA1"); - } - - { - auto const tokenA2 = credA.GetToken(getCached, {}); // Should get previously populated value - EXPECT_EQ(tokenA2.Token, "SecretA1"); - } - - { - auto const tokenB = credB.GetToken(getCached, {}); - EXPECT_EQ( - tokenB.Token, - "SecretB2"); // if token cache was shared between instances, the value would be - // "SecretA1" - } - - { - auto const tokenA3 = credA.GetToken(getCached, {}); // Should still get the cached value - EXPECT_EQ(tokenA3.Token, "SecretA1"); - } - - auto getNew = getCached; - getNew.MinimumExpiration += 3600s; - - { - auto const tokenA4 = credA.GetToken(getNew, {}); // Should get the new value - EXPECT_EQ(tokenA4.Token, "SecretA3"); - } - - { - auto const tokenA5 = credA.GetToken(getNew, {}); // Should get the new value - EXPECT_EQ(tokenA5.Token, "SecretA4"); - } - - { - auto const tokenA6 - = credA.GetToken(getCached, {}); // Should get the cached, recently refreshed value - EXPECT_EQ(tokenA6.Token, "SecretA4"); - } -} - -TEST(TokenCache, TenantId) -{ - TestableTokenCache tokenCache; - - EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - - DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; - - { - auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + // One more addition to the cache and cleanup for the expired ones will get triggered. + static_cast(tokenCache.GetToken("56", {}, 2min, [=]() { AccessToken result; - result.Token = "AX"; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // We were at 55 before we added 1 more, and now we're at 53. 3 were deleted, 1 was added. + EXPECT_EQ(tokenCache.m_cache.size(), 53UL); + + // Items from 1 to 3 should no longer be in the cache. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + + // Let's expire items from 21 all the way up to 56. + for (auto i = 21; i <= 56; ++i) + { + auto const n = std::to_string(i); + tokenCache.m_cache[{n, {}}]->AccessToken.ExpiresOn = Yesterday; + } + + // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get + // to 55 items (with numbers from 2 to 56, and number 1 missing). + for (auto i = 2; i <= 3; ++i) + { + auto const n = std::to_string(i); + static_cast(tokenCache.GetToken(n, {}, 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + } + + // Cache is now at 55 again (items from 2 to 56). Adding 1 more will trigger cleanup. + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); + + // Now let's lock some of the items for reading, and some for writing. Cleanup should not block + // on token release, but will simply move on, without doing anything to the ones that were + // locked. Out of 4 locked, two are expired, so they should get cleared under normal + // circumstances, but this time they will remain in the cache. + std::shared_lock readLockForUnexpired( + tokenCache.m_cache[{"2", {}}]->ElementMutex); + + std::shared_lock readLockForExpired( + tokenCache.m_cache[{"54", {}}]->ElementMutex); + + std::unique_lock writeLockForUnexpired( + tokenCache.m_cache[{"3", {}}]->ElementMutex); + + std::unique_lock writeLockForExpired( + tokenCache.m_cache[{"55", {}}]->ElementMutex); + + // Count is at 55. Inserting the 56th element, and it will trigger cleanup. + static_cast(tokenCache.GetToken("1", {}, 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // These should be 20 unexpired items + two that are expired but were locked, so 22 total. + EXPECT_EQ(tokenCache.m_cache.size(), 22UL); + + for (auto i = 1; i <= 20; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + + EXPECT_NE(tokenCache.m_cache.find({"54", {}}), tokenCache.m_cache.end()); + + EXPECT_NE(tokenCache.m_cache.find({"55", {}}), tokenCache.m_cache.end()); + + for (auto i = 21; i <= 53; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(tokenCache.m_cache.find({n, {}}), tokenCache.m_cache.end()); + } + } + + TEST(TokenCache, MinimumExpiration) + { + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { + AccessToken result; + result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; }); EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AX"); + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = tokenCache.GetToken("A", {}, 24h, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); + EXPECT_EQ(token2.Token, "T2"); } + TEST(TokenCache, MultithreadedAccess) { - auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + TestableTokenCache tokenCache; + + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = tokenCache.GetToken("A", {}, 2min, [=]() { AccessToken result; - result.Token = "BX"; + result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; }); - EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BX"); + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + { + std::shared_lock itemReadLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); + + { + std::shared_lock cacheReadLock(tokenCache.m_cacheMutex); + + // Parallel threads read both the container and the item we're accessing, and we can + // access it in parallel as well. + auto const token2 = tokenCache.GetToken("A", {}, 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); + EXPECT_EQ(token2.Token, token1.Token); + } + + // The cache is unlocked, but one item is being read in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = tokenCache.GetToken("B", {}, 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 2h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); + EXPECT_EQ(token3.Token, "T3"); + } + + { + std::unique_lock itemWriteLock( + tokenCache.m_cache[{"A", {}}]->ElementMutex); + + // The cache is unlocked, but one item is being written in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = tokenCache.GetToken("C", {}, 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); + EXPECT_EQ(token3.Token, "T4"); + } } + using Azure::Core::Context; + using Azure::Core::Http::HttpTransport; + using Azure::Core::Http::RawResponse; + using Azure::Core::Http::Request; + + namespace { + class TestTransport final : public HttpTransport { + int m_attemptNumber = 0; + std::vector m_responseBuf; + + public: + // Returns token response with 3600 seconds expiration (1 hour), and the value of the + // client_secret parameter from the body + attempt number as token value. + std::unique_ptr Send(Request& request, Context const&) override + { + using Azure::Core::Http::HttpStatusCode; + using Azure::Core::IO::BodyStream; + using Azure::Core::IO::MemoryBodyStream; + + ++m_attemptNumber; + + std::string clientSecret; + { + std::string const ClientSecretStart = "client_secret="; + + auto const reqBodyVec = request.GetBodyStream()->ReadToEnd(); + auto const reqBodyStr = std::string(reqBodyVec.cbegin(), reqBodyVec.cend()); + + auto clientSecretStartPos = reqBodyStr.find(ClientSecretStart); + if (clientSecretStartPos != std::string::npos) + { + clientSecretStartPos += ClientSecretStart.size(); + auto const clientSecretEndPos = reqBodyStr.find('&', clientSecretStartPos); + + clientSecret = (clientSecretEndPos == std::string::npos) + ? reqBodyStr.substr(clientSecretStartPos) + : reqBodyStr.substr( + clientSecretStartPos, clientSecretEndPos - clientSecretStartPos); + } + } + + auto const respBodyStr = std::string("{ \"access_token\" : \"") + clientSecret + + std::to_string(m_attemptNumber) + "\", \"expires_in\" : 3600 }"; + + m_responseBuf.assign(respBodyStr.cbegin(), respBodyStr.cend()); + + auto resp = std::make_unique(1, 1, HttpStatusCode::Ok, "OK"); + resp->SetBodyStream(std::make_unique(m_responseBuf)); + return resp; + } + }; + } // namespace + + TEST(TokenCache, PerCredInstance) { - auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { - AccessToken result; - result.Token = "AY"; - result.ExpiresOn = Tomorrow; - return result; - }); + using Azure::Core::Credentials::TokenCredentialOptions; + using Azure::Core::Credentials::TokenRequestContext; + using Azure::Identity::ClientSecretCredential; - EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + TokenRequestContext getCached; + getCached.Scopes = {"https://vault.azure.net/.default"}; + getCached.MinimumExpiration = 1s; - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AY"); + TokenCredentialOptions credOptions; + credOptions.Transport.Transport = std::make_shared(); + + ClientSecretCredential credA("TenantId", "ClientId", "SecretA", credOptions); + ClientSecretCredential credB("TenantId", "ClientId", "SecretB", credOptions); + + { + auto const tokenA1 = credA.GetToken(getCached, {}); // Should populate + EXPECT_EQ(tokenA1.Token, "SecretA1"); + } + + { + auto const tokenA2 = credA.GetToken(getCached, {}); // Should get previously populated value + EXPECT_EQ(tokenA2.Token, "SecretA1"); + } + + { + auto const tokenB = credB.GetToken(getCached, {}); + EXPECT_EQ( + tokenB.Token, + "SecretB2"); // if token cache was shared between instances, the value would be + // "SecretA1" + } + + { + auto const tokenA3 = credA.GetToken(getCached, {}); // Should still get the cached value + EXPECT_EQ(tokenA3.Token, "SecretA1"); + } + + auto getNew = getCached; + getNew.MinimumExpiration += 3600s; + + { + auto const tokenA4 = credA.GetToken(getNew, {}); // Should get the new value + EXPECT_EQ(tokenA4.Token, "SecretA3"); + } + + { + auto const tokenA5 = credA.GetToken(getNew, {}); // Should get the new value + EXPECT_EQ(tokenA5.Token, "SecretA4"); + } + + { + auto const tokenA6 + = credA.GetToken(getCached, {}); // Should get the cached, recently refreshed value + EXPECT_EQ(tokenA6.Token, "SecretA4"); + } } + TEST(TokenCache, TenantId) { - auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { - AccessToken result; - result.Token = "BY"; - result.ExpiresOn = Tomorrow; - return result; - }); + TestableTokenCache tokenCache; - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BY"); + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + AccessToken result; + result.Token = "AX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + AccessToken result; + result.Token = "BX"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 2UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + AccessToken result; + result.Token = "AY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 3UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + AccessToken result; + result.Token = "BY"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } + + { + auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AX"); + } + + { + auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "XB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BX"); + } + + { + auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YA"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "AY"); + } + + { + auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "YB"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(tokenCache.m_cache.size(), 4UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "BY"); + } } - { - auto const token = tokenCache.GetToken("A", "X", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "XA"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AX"); - } - - { - auto const token = tokenCache.GetToken("B", "X", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "XB"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BX"); - } - - { - auto const token = tokenCache.GetToken("A", "Y", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "YA"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "AY"); - } - - { - auto const token = tokenCache.GetToken("B", "Y", 2min, [=]() { - EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); - AccessToken result; - result.Token = "YB"; - result.ExpiresOn = Tomorrow + 24h; - return result; - }); - - EXPECT_EQ(tokenCache.m_cache.size(), 4UL); - - EXPECT_EQ(token.ExpiresOn, Tomorrow); - EXPECT_EQ(token.Token, "BY"); - } -} +}}} // namespace Azure::Identity::Test diff --git a/sdk/keyvault/CMakeLists.txt b/sdk/keyvault/CMakeLists.txt index eec39c807..74f985fd4 100644 --- a/sdk/keyvault/CMakeLists.txt +++ b/sdk/keyvault/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) if(BUILD_TESTING) # define a symbol that enables some test hooks in code add_compile_definitions(TESTING_BUILD) + add_compile_definitions(_azure_TESTING_BUILD) endif() add_subdirectory(azure-security-keyvault-keys) diff --git a/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt b/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt index 084e41fb6..b20e05af1 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt +++ b/sdk/keyvault/azure-security-keyvault-certificates/CMakeLists.txt @@ -103,6 +103,7 @@ generate_documentation(azure-security-keyvault-certificates ${AZ_LIBRARY_VERSION if(BUILD_TESTING) # define a symbol that enables some test hooks in code add_compile_definitions(TESTING_BUILD) + add_compile_definitions(_azure_TESTING_BUILD) if (NOT AZ_ALL_LIBRARIES OR FETCH_SOURCE_DEPS) include(AddGoogleTest) diff --git a/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp b/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp index d43fc8969..0c35fbea5 100644 --- a/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-certificates/inc/azure/keyvault/certificates/certificate_client.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -33,11 +34,7 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Certificat * * @details The client supports retrieving KeyVaultCertificate. */ - class CertificateClient -#if !defined(TESTING_BUILD) - final -#endif - { + class CertificateClient _azure_NON_FINAL_FOR_TESTS { friend class CreateCertificateOperation; #if defined(TESTING_BUILD) diff --git a/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp b/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp index 79c5f13fa..633348325 100644 --- a/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-keys/inc/azure/keyvault/keys/key_client.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -34,11 +35,7 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Keys { * Vault. The client supports creating, retrieving, updating, deleting, purging, backing up, * restoring, and listing the KeyVaultKey. */ - class KeyClient -#if !defined(TESTING_BUILD) - final -#endif - { + class KeyClient _azure_NON_FINAL_FOR_TESTS { protected: // Using a shared pipeline for a client to share it with LRO (like delete key) /** @brief the base URL for this keyvault instance. */ diff --git a/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp b/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp index 6cb87e5da..ba39aa20b 100644 --- a/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp +++ b/sdk/keyvault/azure-security-keyvault-secrets/inc/azure/keyvault/secrets/secret_client.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -42,11 +43,7 @@ namespace Azure { namespace Security { namespace KeyVault { namespace Secrets { * Vault. The client supports creating, retrieving, updating, deleting, purging, backing up, * restoring, and listing the secret. */ - class SecretClient -#if !defined(TESTING_BUILD) - final -#endif - { + class SecretClient _azure_NON_FINAL_FOR_TESTS { private: // Using a shared pipeline for a client to share it with LRO (like delete key)