azure-sdk-for-cpp/sdk/identity/azure-identity/src/token_cache.cpp
Anton Kolesnyk 9ab6a1f62a
Clean up token cache from expired items on Fibonacci cache sizes instead of 2^Ns (#4180)
* Clean up token cache from expired items on Fibonacci cache sizes instead of 2^N

Co-authored-by: Anton Kolesnyk <antkmsft@users.noreply.github.com>
2022-12-13 18:24:57 -08:00

161 lines
4.4 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/identity/detail/token_cache.hpp"
#include <algorithm>
#include <array>
#include <limits>
#include <mutex>
using Azure::Identity::_detail::TokenCache;
using Azure::DateTime;
using Azure::Core::Credentials::AccessToken;
// Reports whether the cached token is still usable: it must not expire before `now` advanced by
// `minimumExpiration`. A token expiring exactly at that instant is treated as stale.
bool TokenCache::IsFresh(
    std::shared_ptr<TokenCache::CacheValue> const& item,
    DateTime::duration minimumExpiration,
    std::chrono::system_clock::time_point now)
{
  auto const mustOutlive = DateTime(now) + minimumExpiration;
  return item->AccessToken.ExpiresOn > mustOutlive;
}
namespace {
// Declared ahead of GetOrCreateValue(), which calls it; the definition is at the end of this
// file, next to the Fibonacci table it consults.
template <class T> bool ShouldCleanUpCacheFromExpiredItems(T cacheSize);
} // namespace
// Returns the cache entry for `key`, inserting a blank entry if none exists yet.
// Lookup is first attempted under a shared lock on m_cacheMutex; only on a miss is the exclusive
// lock taken. While holding it, the cache is opportunistically swept of entries that are not
// fresh relative to `minimumExpiration`, but only when ShouldCleanUpCacheFromExpiredItems()
// approves the current cache size.
std::shared_ptr<TokenCache::CacheValue> TokenCache::GetOrCreateValue(
    std::string const& key,
    DateTime::duration minimumExpiration) const
{
  {
    std::shared_lock<std::shared_timed_mutex> readLock(m_cacheMutex);
    auto const existing = m_cache.find(key);
    if (existing != m_cache.end())
    {
      return existing->second;
    }
  }

#if defined(TESTING_BUILD)
  OnBeforeCacheWriteLock();
#endif

  std::unique_lock<std::shared_timed_mutex> writeLock(m_cacheMutex);

  // The key may have been inserted by another thread between releasing the shared lock and
  // acquiring the exclusive one, so look it up once more.
  {
    auto const existing = m_cache.find(key);
    if (existing != m_cache.end())
    {
      return existing->second;
    }
  }

  if (ShouldCleanUpCacheFromExpiredItems(m_cache.size()))
  {
    auto const now = std::chrono::system_clock::now();
    for (auto entry = m_cache.begin(); entry != m_cache.end();)
    {
      // Hold a strong reference so the element (and its mutex) stays alive while locked, even
      // after the map entry is erased. Declared before the lock so it is destroyed after it.
      auto const value = entry->second;

      // Non-blocking exclusive lock: any element currently held for reading or writing is
      // skipped instead of waited on.
      std::unique_lock<std::shared_timed_mutex> elementLock(value->ElementMutex, std::defer_lock);
      if (elementLock.try_lock() && !IsFresh(value, minimumExpiration, now))
      {
        // erase() yields the iterator following the removed entry.
        entry = m_cache.erase(entry);
      }
      else
      {
        ++entry;
      }
    }
  }

  // Insert a blank value for the key and hand it back to the caller.
  return m_cache[key] = std::make_shared<CacheValue>();
}
// Returns the cached token for `scopeString` if it is fresh relative to `minimumExpiration`;
// otherwise calls `getNewToken`, stores its result in the cache entry, and returns it.
// Freshness is first checked under a shared lock on the entry. The refresh happens under an
// exclusive lock after re-checking, so concurrent callers reuse a token another thread has just
// fetched. If `getNewToken` throws, the stored token is left unchanged.
AccessToken TokenCache::GetToken(
    std::string const& scopeString,
    DateTime::duration minimumExpiration,
    std::function<AccessToken()> const& getNewToken) const
{
  auto const cacheValue = GetOrCreateValue(scopeString, minimumExpiration);

  {
    std::shared_lock<std::shared_timed_mutex> readLock(cacheValue->ElementMutex);
    if (IsFresh(cacheValue, minimumExpiration, std::chrono::system_clock::now()))
    {
      return cacheValue->AccessToken;
    }
  }

#if defined(TESTING_BUILD)
  OnBeforeItemWriteLock();
#endif

  std::unique_lock<std::shared_timed_mutex> writeLock(cacheValue->ElementMutex);

  // Another thread may have refreshed the token after the shared lock was released; only fetch
  // a new one if it is still stale.
  if (!IsFresh(cacheValue, minimumExpiration, std::chrono::system_clock::now()))
  {
    cacheValue->AccessToken = getNewToken();
  }

  return cacheValue->AccessToken;
}
namespace {
// Fibonacci sequence whose terms are computed at compile time, as template arguments.
// Get() returns a std::array<T, ...> holding every term, from the seed L onwards, that fits
// into T, in ascending order (for non-negative seeds with L <= R). The last element is the
// largest term representable in T.
// NOTE: Get() is marked constexpr, but it fills the array through the non-const
// std::array::operator[], which is constexpr only since C++17. That is presumably why the
// caller below keeps the table in a plain function-local static, not a constexpr variable.
template <
    typename T, // Element type
    T L = 0, // Current term (stored at index N)
    T R = 1, // Next term
    size_t N = 0, // Index of L in the resulting array
    bool X = ((std::numeric_limits<T>::max() - L) < R)> // Stop: L + R would overflow T
struct SortedFibonacciSequence
{
  static constexpr auto Get();
};

// Terminal case: L + R no longer fits into T, so no further term can be produced.
// Both L and R are still representable, so both are emitted. Previously only L was emitted,
// which dropped the largest Fibonacci number that fits into T.
template <typename T, T L, T R, size_t N> struct SortedFibonacciSequence<T, L, R, N, true>
{
  static constexpr auto Get()
  {
    std::array<T, N + 2> result{};
    result[N] = L;
    result[N + 1] = R;
    return result;
  }
};

// Recursive case: L + R fits into T. Build the tail starting at R, then fill in L at index N.
template <typename T, T L, T R, size_t N, bool X>
constexpr auto SortedFibonacciSequence<T, L, R, N, X>::Get()
{
  auto result = SortedFibonacciSequence<T, R, R + L, N + 1>::Get();
  result[N] = L;
  return result;
}

// Returns true when cacheSize is one of the Fibonacci numbers 1, 2, 3, 5, 8, 13, ...
// representable in T; returns false for every other value, including 0.
template <typename T> bool ShouldCleanUpCacheFromExpiredItems(T cacheSize)
{
  static auto const Fibonacci = SortedFibonacciSequence<T, 1, 2>::Get();
  return std::binary_search(Fibonacci.begin(), Fibonacci.end(), cacheSize);
}
} // namespace