diff --git a/sdk/identity/azure-identity/src/token_cache.cpp b/sdk/identity/azure-identity/src/token_cache.cpp index 18e66cfcf..a69f87773 100644 --- a/sdk/identity/azure-identity/src/token_cache.cpp +++ b/sdk/identity/azure-identity/src/token_cache.cpp @@ -3,6 +3,9 @@ #include "azure/identity/detail/token_cache.hpp" +#include +#include +#include #include using Azure::Identity::_detail::TokenCache; @@ -18,6 +21,10 @@ bool TokenCache::IsFresh( return item->AccessToken.ExpiresOn > (DateTime(now) + minimumExpiration); } +namespace { +template bool ShouldCleanUpCacheFromExpiredItems(T cacheSize); +} + std::shared_ptr TokenCache::GetOrCreateValue( std::string const& key, DateTime::duration minimumExpiration) const @@ -46,38 +53,28 @@ std::shared_ptr TokenCache::GetOrCreateValue( return found->second; } - // Clean up cache from expired items (once every N insertions). + // Clean up cache from expired items. + if (ShouldCleanUpCacheFromExpiredItems(m_cache.size())) { - auto const cacheSize = m_cache.size(); + auto now = std::chrono::system_clock::now(); - // N: cacheSize (before insertion) is >= 32 and is a power of two. - // 32 as a starting point does not have any special meaning. - // - // Power of 2 trick: - // https://www.exploringbinary.com/ten-ways-to-check-if-an-integer-is-a-power-of-two-in-c/ - - if (cacheSize >= 32 && (cacheSize & (cacheSize - 1)) == 0) + auto iter = m_cache.begin(); + while (iter != m_cache.end()) { - auto now = std::chrono::system_clock::now(); + // Should we end up erasing the element, iterator to current will become invalid, after + // which we can't increment it. So we copy current, and safely advance the loop iterator. + auto const curr = iter; + ++iter; - auto iter = m_cache.begin(); - while (iter != m_cache.end()) + // We will try to obtain a write lock, but in a non-blocking way. We only lock it if no one + // was holding it for read and write at a time. If it's busy in any way, we don't wait, but + // move on. + auto const item = curr->second; { - // Should we end up erasing the element, iterator to current will become invalid, after - // which we can't increment it. So we copy current, and safely advance the loop iterator. - auto const curr = iter; - ++iter; - - // We will try to obtain a write lock, but in a non-blocking way. We only lock it if no one - // was holding it for read and write at a time. If it's busy in any way, we don't wait, but - // move on. - auto const item = curr->second; + std::unique_lock lock(item->ElementMutex, std::defer_lock); + if (lock.try_lock() && !IsFresh(item, minimumExpiration, now)) { - std::unique_lock lock(item->ElementMutex, std::defer_lock); - if (lock.try_lock() && !IsFresh(item, minimumExpiration, now)) - { - m_cache.erase(curr); - } + m_cache.erase(curr); } } } @@ -120,3 +117,44 @@ AccessToken TokenCache::GetToken( item->AccessToken = newToken; return newToken; } + +namespace { + +// Compile-time Fibonacci sequence computation. +// Get() produces a std::array containing the numbers in ascending order. +template < + typename T, // Type + T L = 0, // Left hand side + T R = 1, // Right hand side + size_t N = 0, // Counter (for array) + bool X = ((std::numeric_limits::max() - L) < R)> // Condition to stop (integer overflow of T) +struct SortedFibonacciSequence +{ + static constexpr auto Get(); +}; + +template struct SortedFibonacciSequence +{ + static constexpr auto Get() + { + std::array result{}; + result[N] = L; + return result; + } +}; + +template +constexpr auto SortedFibonacciSequence::Get() +{ + auto result = SortedFibonacciSequence::Get(); + result[N] = L; + return result; +} + +template bool ShouldCleanUpCacheFromExpiredItems(T cacheSize) +{ + static auto const Fibonacci = SortedFibonacciSequence::Get(); + return std::binary_search(Fibonacci.begin(), Fibonacci.end(), cacheSize); +} + +} // namespace diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp index cc72e64bf..14cdc95ca 100644 --- a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -186,13 +186,15 @@ TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) TEST(TokenCache, ExpiredCleanup) { + // Expected cleanup points are when cache size is in the Fibonacci sequence: + // 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ... DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; auto const Yesterday = Tomorrow - 48h; TestableTokenCache tokenCache; EXPECT_EQ(tokenCache.m_cache.size(), 0UL); - for (auto i = 1; i <= 65; ++i) + for (auto i = 1; i <= 35; ++i) { auto const n = std::to_string(i); static_cast(tokenCache.GetToken(n, 2min, [=]() { @@ -203,8 +205,8 @@ TEST(TokenCache, ExpiredCleanup) })); } - // Simply: we added 64+1 token, none of them has expired. None are expected to be cleaned up. - EXPECT_EQ(tokenCache.m_cache.size(), 65UL); + // Simply: we added 34+1 token, none of them has expired. None are expected to be cleaned up. + EXPECT_EQ(tokenCache.m_cache.size(), 35UL); // Let's expire 3 of them, with numbers from 1 to 3. for (auto i = 1; i <= 3; ++i) @@ -213,8 +215,8 @@ TEST(TokenCache, ExpiredCleanup) tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; } - // Add tokens up to 128 total. When 129th gets added, clean up should get triggered. - for (auto i = 66; i <= 128; ++i) + // Add tokens up to 55 total. When 56th gets added, clean up should get triggered. + for (auto i = 36; i <= 55; ++i) { auto const n = std::to_string(i); static_cast(tokenCache.GetToken(n, 2min, [=]() { @@ -225,9 +227,9 @@ TEST(TokenCache, ExpiredCleanup) })); } - EXPECT_EQ(tokenCache.m_cache.size(), 128UL); + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); - // Count is at 128. Tokens from 1 to 3 are still in cache even though they are expired. + // Count is at 55. Tokens from 1 to 3 are still in cache even though they are expired. for (auto i = 1; i <= 3; ++i) { auto const n = std::to_string(i); @@ -235,15 +237,15 @@ TEST(TokenCache, ExpiredCleanup) } // One more addition to the cache and cleanup for the expired ones will get triggered. - static_cast(tokenCache.GetToken("129", 2min, [=]() { + static_cast(tokenCache.GetToken("56", 2min, [=]() { AccessToken result; result.Token = "T1"; result.ExpiresOn = Tomorrow; return result; })); - // We were at 128 before we added 1 more, and now we're at 126. 3 were deleted, 1 was added. - EXPECT_EQ(tokenCache.m_cache.size(), 126UL); + // We were at 55 before we added 1 more, and now we're at 53. 3 were deleted, 1 was added. + EXPECT_EQ(tokenCache.m_cache.size(), 53UL); // Items from 1 to 3 should no longer be in the cache. for (auto i = 1; i <= 3; ++i) @@ -252,15 +254,15 @@ TEST(TokenCache, ExpiredCleanup) EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } - // Let's expire items from 21 all the way up to 129. - for (auto i = 21; i <= 129; ++i) + // Let's expire items from 21 all the way up to 56. + for (auto i = 21; i <= 56; ++i) { auto const n = std::to_string(i); tokenCache.m_cache[n]->AccessToken.ExpiresOn = Yesterday; } // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to - // 128 items (with numbers from 2 to 129, and number 1 missing). + // 55 items (with numbers from 2 to 56, and number 1 missing). for (auto i = 2; i <= 3; ++i) { auto const n = std::to_string(i); @@ -272,26 +274,26 @@ TEST(TokenCache, ExpiredCleanup) })); } - // Cache is now at 128 again (items from 2 to 129). Adding 1 more will trigger cleanup. - EXPECT_EQ(tokenCache.m_cache.size(), 128UL); + // Cache is now at 55 again (items from 2 to 56). Adding 1 more will trigger cleanup. + EXPECT_EQ(tokenCache.m_cache.size(), 55UL); // Now let's lock some of the items for reading, and some for writing. Cleanup should not block on // token release, but will simply move on, without doing anything to the ones that were locked. - // Out of 4 locked, two are expired, so they should get cleared under normla circumstances, but + // Out of 4 locked, two are expired, so they should get cleared under normal circumstances, but // this time they will remain in the cache. std::shared_lock readLockForUnexpired( tokenCache.m_cache["2"]->ElementMutex); std::shared_lock readLockForExpired( - tokenCache.m_cache["127"]->ElementMutex); + tokenCache.m_cache["54"]->ElementMutex); std::unique_lock writeLockForUnexpired( tokenCache.m_cache["3"]->ElementMutex); std::unique_lock writeLockForExpired( - tokenCache.m_cache["128"]->ElementMutex); + tokenCache.m_cache["55"]->ElementMutex); - // Count is at 128. Inserting the 129th element, and it will trigger cleanup. + // Count is at 55. Inserting the 56th element, and it will trigger cleanup. static_cast(tokenCache.GetToken("1", 2min, [=]() { AccessToken result; result.Token = "T2"; @@ -308,11 +310,11 @@ TEST(TokenCache, ExpiredCleanup) EXPECT_NE(tokenCache.m_cache.find(n), tokenCache.m_cache.end()); } - EXPECT_NE(tokenCache.m_cache.find("127"), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find("54"), tokenCache.m_cache.end()); - EXPECT_NE(tokenCache.m_cache.find("128"), tokenCache.m_cache.end()); + EXPECT_NE(tokenCache.m_cache.find("55"), tokenCache.m_cache.end()); - for (auto i = 21; i <= 126; ++i) + for (auto i = 21; i <= 53; ++i) { auto const n = std::to_string(i); EXPECT_EQ(tokenCache.m_cache.find(n), tokenCache.m_cache.end());