diff --git a/sdk/core/azure-core/inc/credentials/credentials.hpp b/sdk/core/azure-core/inc/credentials/credentials.hpp index b516c812d..2e2dec588 100644 --- a/sdk/core/azure-core/inc/credentials/credentials.hpp +++ b/sdk/core/azure-core/inc/credentials/credentials.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -14,6 +15,11 @@ namespace core namespace credentials { +namespace detail +{ +class CredentialTest; +} + class Credential { virtual void SetScopes(std::string const& scopes) { (void)scopes; } @@ -31,72 +37,60 @@ protected: class TokenCredential : public Credential { - struct Token - { - public: - std::string Scopes; - std::string TokenString; - std::chrono::system_clock::time_point ExpiresAt; - }; + friend class detail::CredentialTest; + class Token; - Token m_token; - mutable std::mutex m_tokenMutex; + std::shared_ptr m_token; + std::mutex m_mutex; - void SetScopes(std::string const& scopes) override; + std::string UpdateTokenNonThreadSafe(Token& token); - Token GetToken() const; + virtual bool IsTokenExpired(std::chrono::system_clock::time_point const& tokenExpiration) const; - void SetToken( - std::string const& tokenString, - std::chrono::system_clock::time_point const& expiresAt); - - void SetToken(Token const& token); + virtual void RefreshToken( + std::string& newTokenString, + std::chrono::system_clock::time_point& newExpiration) + = 0; public: class Internal; + TokenCredential(TokenCredential const& other); + TokenCredential& operator=(TokenCredential const& other); + protected: TokenCredential() = default; + TokenCredential(TokenCredential const& other, int) : Credential(other) {} - TokenCredential(TokenCredential const& other) : Credential(other), m_token(other.GetToken()) {} - - TokenCredential& operator=(TokenCredential const& other) - { - this->SetToken(other.GetToken()); - return *this; - } + void Init(TokenCredential const& other); + virtual std::string GetToken(); + void ResetToken(); }; class ClientSecretCredential : public TokenCredential { - std::string m_tenantId; - std::string m_clientId; - std::string m_clientSecret; + friend class detail::CredentialTest; + class ClientSecret; + + std::shared_ptr m_clientSecret; + std::mutex m_mutex; + + void SetScopes(std::string const& scopes) override; + + std::string GetToken() override; + + void RefreshToken( + std::string& newTokenString, + std::chrono::system_clock::time_point& newExpiration) override; public: - class Internal; - ClientSecretCredential( std::string const& tenantId, std::string const& clientId, - std::string const& clientSecret) - : m_tenantId(tenantId), m_clientId(clientId), m_clientSecret(clientSecret) - { - } + std::string const& clientSecret); - ClientSecretCredential(ClientSecretCredential const& other) - : TokenCredential(other), m_tenantId(other.m_tenantId), m_clientId(other.m_clientId), - m_clientSecret(other.m_clientSecret) - { - } - - ClientSecretCredential& operator=(ClientSecretCredential const& other) - { - this->m_tenantId = other.m_tenantId; - this->m_clientId = other.m_clientId; - this->m_clientSecret = other.m_clientSecret; - return *this; - } + ClientSecretCredential(ClientSecretCredential const& other); + ClientSecretCredential& operator=(ClientSecretCredential const& other); }; } // namespace credentials diff --git a/sdk/core/azure-core/inc/internal/credentials_internal.hpp b/sdk/core/azure-core/inc/internal/credentials_internal.hpp index ef868a2aa..4ea38bcfe 100644 --- a/sdk/core/azure-core/inc/internal/credentials_internal.hpp +++ b/sdk/core/azure-core/inc/internal/credentials_internal.hpp @@ -5,6 +5,11 @@ #include +#include +#include +#include +#include + namespace azure { namespace core @@ -21,39 +26,48 @@ public: } }; +class TokenCredential::Token +{ + friend class TokenCredential; + friend class detail::CredentialTest; + + std::string m_tokenString; + std::chrono::system_clock::time_point m_expiresAt; + std::mutex m_mutex; +}; + class TokenCredential::Internal { public: - static Token GetToken(TokenCredential const& credential) - { - return credential.GetToken(); - } - - static void SetToken( - TokenCredential& credential, - std::string const& token, - std::chrono::system_clock::time_point const& expiration) - { - credential.SetToken(token, expiration); - } + static std::string GetToken(TokenCredential& credential) { return credential.GetToken(); } }; -class ClientSecretCredential::Internal +class ClientSecretCredential::ClientSecret { + friend class ClientSecretCredential; + friend class detail::CredentialTest; + + std::string m_tenantId; + std::string m_clientId; + std::string m_clientSecret; + std::string m_scopes; + public: - static std::string const& GetTenantId(ClientSecretCredential const& credential) + ClientSecret( + std::string const& tenantId, + std::string const& clientId, + std::string const& clientSecret) + : m_tenantId(tenantId), m_clientId(clientId), m_clientSecret(clientSecret) { - return credential.m_tenantId; } - static std::string const& GetClientId(ClientSecretCredential const& credential) + ClientSecret( + std::string const& tenantId, + std::string const& clientId, + std::string const& clientSecret, + std::string const& scopes) + : m_tenantId(tenantId), m_clientId(clientId), m_clientSecret(clientSecret), m_scopes(scopes) { - return credential.m_clientId; - } - - static std::string const& GetClientSecret(ClientSecretCredential const& credential) - { - return credential.m_clientSecret; } }; diff --git a/sdk/core/azure-core/src/credentials/credentials.cpp b/sdk/core/azure-core/src/credentials/credentials.cpp index 5364e1494..5d10d76a0 100644 --- a/sdk/core/azure-core/src/credentials/credentials.cpp +++ b/sdk/core/azure-core/src/credentials/credentials.cpp @@ -2,33 +2,124 @@ // SPDX-License-Identifier: MIT #include +#include using namespace azure::core::credentials; -void TokenCredential::SetScopes(std::string const& scopes) +std::string TokenCredential::UpdateTokenNonThreadSafe(Token& token) { - std::lock_guard const lock(this->m_tokenMutex); - if (this->m_token.Scopes != scopes) + std::string newTokenString; + std::chrono::system_clock::time_point newExpiration; + + this->RefreshToken(newTokenString, newExpiration); + + token.m_tokenString = newTokenString; + token.m_expiresAt = newExpiration; + + return newTokenString; +} + +bool TokenCredential::IsTokenExpired( + std::chrono::system_clock::time_point const& tokenExpiration) const +{ + return tokenExpiration <= std::chrono::system_clock::now() - std::chrono::minutes(5); +} + +TokenCredential::TokenCredential(TokenCredential const& other) : Credential(other) +{ + this->Init(other); +} + +TokenCredential& TokenCredential::operator=(TokenCredential const& other) +{ + std::lock_guard const thisTokenPtrLock(this->m_mutex); + this->Init(other); + return *this; +} + +void TokenCredential::Init(TokenCredential const& other) +{ + std::lock_guard const otherTokenPtrLock(const_cast(other.m_mutex)); + this->m_token = other.m_token; +} + +std::string TokenCredential::GetToken() +{ + std::lock_guard const tokenPtrLock(this->m_mutex); + + if (!this->m_token) { - this->m_token = { scopes, {}, {} }; + this->m_token = std::make_shared(); + return UpdateTokenNonThreadSafe(*this->m_token); } + + std::lock_guard const tokenLock(this->m_token->m_mutex); + Token& token = *this->m_token; + return this->IsTokenExpired(token.m_expiresAt) ? UpdateTokenNonThreadSafe(token) : token.m_tokenString; } -TokenCredential::Token TokenCredential::GetToken() const +void TokenCredential::ResetToken() { - std::lock_guard const lock(this->m_tokenMutex); - return this->m_token; + std::lock_guard const tokenPtrLock(this->m_mutex); + this->m_token.reset(); } -void TokenCredential::SetToken( - std::string const& tokenString, - std::chrono::system_clock::time_point const& expiresAt) +void ClientSecretCredential::SetScopes(std::string const& scopes) { - this->SetToken({ this->m_token.Scopes, tokenString, expiresAt }); + std::lock_guard const clientSecretPtrLock(this->m_mutex); + + if (scopes == this->m_clientSecret->m_scopes) + return; + + this->TokenCredential::ResetToken(); + + if (this->m_clientSecret.unique()) + { + this->m_clientSecret->m_scopes = scopes; + return; + } + + this->m_clientSecret = std::make_shared( + this->m_clientSecret->m_tenantId, + this->m_clientSecret->m_clientId, + this->m_clientSecret->m_clientSecret, + scopes); } -void TokenCredential::SetToken(Token const& token) +std::string ClientSecretCredential::GetToken() { - std::lock_guard const lock(this->m_tokenMutex); - this->m_token = token; + std::lock_guard const clientSecretPtrLock(this->m_mutex); + return this->TokenCredential::GetToken(); +} + +void ClientSecretCredential::RefreshToken( + std::string& newTokenString, + std::chrono::system_clock::time_point& newExpiration) +{ + // TODO: get token using scopes, tenantId, clientId, and clientSecretId. + (void)newTokenString; + (void)newExpiration; +} + +ClientSecretCredential::ClientSecretCredential( + std::string const& tenantId, + std::string const& clientId, + std::string const& clientSecret) + : m_clientSecret(new ClientSecret(tenantId, clientId, clientSecret)) +{ +} + +ClientSecretCredential::ClientSecretCredential(ClientSecretCredential const& other) + : TokenCredential(other, 0) +{ + std::lock_guard const otherClientSecretPtrLock( + const_cast(other.m_mutex)); + this->TokenCredential::Init(other); +} + +ClientSecretCredential& ClientSecretCredential::operator=(ClientSecretCredential const& other) +{ + std::lock_guard const otherClientSecretPtrLock(this->m_mutex); + this->TokenCredential::operator=(other); + return *this; } diff --git a/sdk/core/azure-core/test/main.cpp b/sdk/core/azure-core/test/main.cpp index 73fcb0253..e0a17bc6e 100644 --- a/sdk/core/azure-core/test/main.cpp +++ b/sdk/core/azure-core/test/main.cpp @@ -150,23 +150,74 @@ TEST(Http_Request, add_path) url + "/path/path2/path3?query=value"); } +class azure::core::credentials::detail::CredentialTest : public ClientSecretCredential +{ +public: + CredentialTest( + std::string const& tenantId, + std::string const& clientId, + std::string const& clientSecret) + : ClientSecretCredential(tenantId, clientId, clientSecret) + { + } + + std::string NewTokenString; + std::chrono::system_clock::time_point NewExpiration; + bool IsExpired; + + std::string GetTenantId() const + { + return this->ClientSecretCredential::m_clientSecret->m_tenantId; + } + + std::string GetClientId() const + { + return this->ClientSecretCredential::m_clientSecret->m_clientId; + } + + std::string GetClientSecret() const + { + return this->ClientSecretCredential::m_clientSecret->m_clientSecret; + } + + std::string GetScopes() const { return this->ClientSecretCredential::m_clientSecret->m_scopes; } + + bool IsTokenPtrNull() const { return !this->TokenCredential::m_token; } + + std::string GetTokenString() const { return this->TokenCredential::m_token->m_tokenString; } + + std::chrono::system_clock::time_point GetExpiration() const + { + return this->TokenCredential::m_token->m_expiresAt; + } + +private: + void RefreshToken( + std::string& newTokenString, + std::chrono::system_clock::time_point& newExpiration) override + { + newTokenString = this->NewTokenString; + newExpiration = this->NewExpiration; + } + + bool IsTokenExpired(std::chrono::system_clock::time_point const&) const override + { + return this->IsExpired; + } +}; + TEST(Credential, ClientSecretCredential) { // Client Secret credential properties - credentials::ClientSecretCredential clientSecretCredential( - "tenantId", "clientId", "clientSecret"); + std::string const tenantId = "tenantId"; + std::string const clientId = "clientId"; + std::string const clientSecret = "clientSecret"; - EXPECT_EQ( - credentials::ClientSecretCredential::Internal::GetTenantId(clientSecretCredential), - "tenantId"); + credentials::detail::CredentialTest clientSecretCredential(tenantId, clientId, clientSecret); - EXPECT_EQ( - credentials::ClientSecretCredential::Internal::GetClientId(clientSecretCredential), - "clientId"); - - EXPECT_EQ( - credentials::ClientSecretCredential::Internal::GetClientSecret(clientSecretCredential), - "clientSecret"); + EXPECT_EQ(clientSecretCredential.GetTenantId(), tenantId); + EXPECT_EQ(clientSecretCredential.GetClientId(), clientId); + EXPECT_EQ(clientSecretCredential.GetClientSecret(), clientSecret); // Token credential { @@ -175,12 +226,7 @@ TEST(Credential, ClientSecretCredential) { // Default values { - auto const initialToken - = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); - - EXPECT_EQ(initialToken.TokenString, emptyString); - EXPECT_EQ(initialToken.Scopes, emptyString); - EXPECT_EQ(initialToken.ExpiresAt, defaultTime); + EXPECT_EQ(clientSecretCredential.IsTokenPtrNull(), true); } { @@ -188,61 +234,129 @@ TEST(Credential, ClientSecretCredential) std::string const scopes = "scope"; { credentials::Credential::Internal::SetScopes(clientSecretCredential, scopes); - - auto const scopedToken - = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); - - EXPECT_EQ(scopedToken.TokenString, emptyString); - EXPECT_EQ(scopedToken.Scopes, scopes); - EXPECT_EQ(scopedToken.ExpiresAt, defaultTime); + EXPECT_EQ(clientSecretCredential.IsTokenPtrNull(), true); } - // Set token + // Get token { - std::string const token = "token"; - auto const recentTime = std::chrono::system_clock::now(); + std::string const olderToken = "olderToken"; + std::string const newToken = "newToken"; + auto const olderTime = defaultTime + std::chrono::minutes(10); + auto const newTime = olderTime + std::chrono::minutes(10); { - credentials::TokenCredential::Internal::SetToken( - clientSecretCredential, token, recentTime); + clientSecretCredential.IsExpired = true; + clientSecretCredential.NewTokenString = olderToken; + clientSecretCredential.NewExpiration = olderTime; - auto const refreshedToken + auto const tokenReceived = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); - EXPECT_EQ(refreshedToken.TokenString, token); - EXPECT_EQ(refreshedToken.Scopes, scopes); - EXPECT_EQ(refreshedToken.ExpiresAt, recentTime); + EXPECT_EQ(clientSecretCredential.IsTokenPtrNull(), false); + EXPECT_EQ(tokenReceived, olderToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), olderToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), scopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), olderTime); + } + + // Attemp to get the token when it is not expired yet + { + clientSecretCredential.IsExpired = false; + clientSecretCredential.NewTokenString = newToken; + clientSecretCredential.NewExpiration = newTime; + + auto const tokenReceived + = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); + + EXPECT_EQ(tokenReceived, olderToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), olderToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), scopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), olderTime); + } + + // Attempt to get token after it expired + { + clientSecretCredential.IsExpired = true; + + auto const tokenReceived + = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); + + EXPECT_EQ(tokenReceived, newToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), newToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), scopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), newTime); + + clientSecretCredential.IsExpired = false; } // Setting the very same scopes set earlier does not reset token { - credentials::Credential::Internal::SetScopes( - clientSecretCredential, std::string(scopes)); + std::string const scopesCopy + = scopes.substr(0, scopes.length() / 2) + scopes.substr(scopes.length() / 2); - auto const rescopedToken + { + auto const scopesPtr = scopes.c_str(); + auto const scopesCopyPtr = scopesCopy.c_str(); + EXPECT_NE(scopesPtr, scopesCopyPtr); + EXPECT_EQ(scopes, scopesCopy); + } + + + credentials::Credential::Internal::SetScopes(clientSecretCredential, scopesCopy); + + EXPECT_EQ(clientSecretCredential.GetTenantId(), tenantId); + EXPECT_EQ(clientSecretCredential.GetClientId(), clientId); + EXPECT_EQ(clientSecretCredential.GetClientSecret(), clientSecret); + + auto const tokenReceived = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); - EXPECT_EQ(rescopedToken.TokenString, token); - EXPECT_EQ(rescopedToken.Scopes, scopes); - EXPECT_EQ(rescopedToken.ExpiresAt, recentTime); + EXPECT_EQ(tokenReceived, newToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), newToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), scopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), newTime); + } + + // Updating scopes does reset the token + { + clientSecretCredential.IsExpired = false; + + std::string const anotherScopes = "anotherScopes"; + std::string const anotherToken = "anotherToken"; + auto const anotherTime = newTime + std::chrono::minutes(10); + + clientSecretCredential.NewTokenString = anotherToken; + clientSecretCredential.NewExpiration = anotherTime; + + auto tokenReceived + = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); + + EXPECT_EQ(tokenReceived, newToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), newToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), scopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), newTime); + + credentials::Credential::Internal::SetScopes( + clientSecretCredential, std::string(anotherScopes)); + + + EXPECT_EQ(clientSecretCredential.GetTenantId(), tenantId); + EXPECT_EQ(clientSecretCredential.GetClientId(), clientId); + EXPECT_EQ(clientSecretCredential.GetClientSecret(), clientSecret); + EXPECT_EQ(clientSecretCredential.GetScopes(), anotherScopes); + EXPECT_EQ(clientSecretCredential.IsTokenPtrNull(), true); + + tokenReceived + = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); + + EXPECT_EQ(clientSecretCredential.IsTokenPtrNull(), false); + EXPECT_EQ(tokenReceived, anotherToken); + EXPECT_EQ(clientSecretCredential.GetTokenString(), anotherToken); + EXPECT_EQ(clientSecretCredential.GetScopes(), anotherScopes); + EXPECT_EQ(clientSecretCredential.GetExpiration(), anotherTime); } } } - - // Updating scopes does reset the token - { - std::string const another_scopes = "another_scopes"; - - credentials::Credential::Internal::SetScopes( - clientSecretCredential, std::string(another_scopes)); - - auto const resetToken - = credentials::TokenCredential::Internal::GetToken(clientSecretCredential); - - EXPECT_EQ(resetToken.TokenString, emptyString); - EXPECT_EQ(resetToken.Scopes, another_scopes); - EXPECT_EQ(resetToken.ExpiresAt, defaultTime); - } } } }