From f22eafad3031a6d02555efb998bd372b9f83751a Mon Sep 17 00:00:00 2001 From: Anton Kolesnyk <41349689+antkmsft@users.noreply.github.com> Date: Wed, 28 Jun 2023 15:05:47 -0700 Subject: [PATCH] Identity: Credentials to accept a wider variety of token responses (#4740) * Identity: Credentials to accept a wider variety of token responses * Restructure code * GCC warning --------- Co-authored-by: Anton Kolesnyk --- sdk/identity/azure-identity/CHANGELOG.md | 2 + .../src/token_credential_impl.cpp | 122 ++++++--- .../test/ut/token_credential_impl_test.cpp | 257 ++++++++++++++++++ 3 files changed, 347 insertions(+), 34 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index e2e59dc74..0dff59807 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -8,6 +8,8 @@ ### Bugs Fixed +- [#4723]](https://github.com/Azure/azure-sdk-for-cpp/issues/4723) Accept a wider variety of token responses. + ### Other Changes ## 1.5.0 (2023-05-04) diff --git a/sdk/identity/azure-identity/src/token_credential_impl.cpp b/sdk/identity/azure-identity/src/token_credential_impl.cpp index 9a496dab4..8a61d2f12 100644 --- a/sdk/identity/azure-identity/src/token_credential_impl.cpp +++ b/sdk/identity/azure-identity/src/token_credential_impl.cpp @@ -15,8 +15,10 @@ using Azure::Identity::_detail::TokenCredentialImpl; using Azure::Identity::_detail::PackageVersion; +using Azure::DateTime; using Azure::Core::Context; using Azure::Core::Url; +using Azure::Core::_internal::PosixTimeConverter; using Azure::Core::Credentials::AccessToken; using Azure::Core::Credentials::AuthenticationException; using Azure::Core::Credentials::TokenCredentialOptions; @@ -126,7 +128,7 @@ AccessToken TokenCredentialImpl::GetToken( std::string(responseBodyVector.begin(), responseBodyVector.end()), "access_token", "expires_in", - std::string()); + "expires_on"); } catch (AuthenticationException const&) { @@ -139,10 +141,10 @@ AccessToken TokenCredentialImpl::GetToken( } namespace { -[[noreturn]] void ThrowMissingJsonPropertyError(std::string const& propertyName) +[[noreturn]] void ThrowJsonPropertyError(std::string const& propertyName) { throw std::runtime_error( - std::string("Token JSON object: \'") + propertyName + "\' property was not found."); + std::string("Token JSON object: can't find or parse \'") + propertyName + "\' property."); } } // namespace @@ -152,44 +154,96 @@ AccessToken TokenCredentialImpl::ParseToken( std::string const& expiresInPropertyName, std::string const& expiresOnPropertyName) { - try + auto const parsedJson = Azure::Core::Json::_internal::json::parse(jsonString); + + if (!parsedJson.contains(accessTokenPropertyName) + || !parsedJson[accessTokenPropertyName].is_string()) { - auto const parsedJson = Azure::Core::Json::_internal::json::parse(jsonString); + ThrowJsonPropertyError(accessTokenPropertyName); + } - if (!parsedJson.contains(accessTokenPropertyName)) - { - ThrowMissingJsonPropertyError(accessTokenPropertyName); - } - - AccessToken accessToken; - accessToken.Token = parsedJson[accessTokenPropertyName].get(); - - if (parsedJson.contains(expiresInPropertyName)) + AccessToken accessToken; + accessToken.Token = parsedJson[accessTokenPropertyName].get(); + accessToken.ExpiresOn = std::chrono::system_clock::now(); + + if (parsedJson.contains(expiresInPropertyName)) + { + auto const& expiresIn = parsedJson[expiresInPropertyName]; + + if (expiresIn.is_number_unsigned()) { + // 'expires_in' as number (seconds until expiration) accessToken.ExpiresOn - = std::chrono::system_clock::now() - + std::chrono::seconds( - parsedJson[expiresInPropertyName].get()); - } - else if (expiresOnPropertyName.empty()) - { - ThrowMissingJsonPropertyError(expiresInPropertyName); - } - else if (!parsedJson.contains(expiresOnPropertyName)) - { - ThrowMissingJsonPropertyError(expiresOnPropertyName); - } - else - { - accessToken.ExpiresOn = Azure::DateTime::Parse( - parsedJson[expiresOnPropertyName].get(), - Azure::DateTime::DateFormat::Rfc3339); + += std::chrono::seconds(expiresIn.get()); + + return accessToken; } - return accessToken; + if (expiresIn.is_string()) + { + try + { + // 'expires_in' as numeric string (seconds until expiration) + accessToken.ExpiresOn += std::chrono::seconds(std::stoi(expiresIn.get())); + + return accessToken; + } + catch (std::exception const&) + { + // stoi() has thrown, we may throw later. + } + } } - catch (Azure::Core::Json::_internal::json::parse_error const& ex) + + if (expiresOnPropertyName.empty()) { - throw std::runtime_error(std::string("Error parsing token JSON: ") + ex.what()); + // 'expires_in' is undefined, 'expires_on' is not expected. + ThrowJsonPropertyError(expiresInPropertyName); } + + if (parsedJson.contains(expiresOnPropertyName)) + { + auto const& expiresOn = parsedJson[expiresOnPropertyName]; + + if (expiresOn.is_number_unsigned()) + { + // 'expires_on' as number (posix time representing an absolute timestamp) + accessToken.ExpiresOn + = PosixTimeConverter::PosixTimeToDateTime(expiresOn.get()); + + return accessToken; + } + + if (expiresOn.is_string()) + { + auto const expiresOnAsString = expiresOn.get(); + for (auto const& parse : { + std::function([](auto const& s) { + // 'expires_on' as RFC3339 date string (absolute timestamp) + return DateTime::Parse(s, DateTime::DateFormat::Rfc3339); + }), + std::function([](auto const& s) { + // 'expires_on' as numeric string (posix time representing an absolute timestamp) + return PosixTimeConverter::PosixTimeToDateTime(std::stoll(s)); + }), + std::function([](auto const& s) { + // 'expires_on' as RFC1123 date string (absolute timestamp) + return DateTime::Parse(s, DateTime::DateFormat::Rfc1123); + }), + }) + { + try + { + accessToken.ExpiresOn = parse(expiresOnAsString); + return accessToken; + } + catch (std::exception const&) + { + // parse() has thrown, we may throw later. + } + } + } + } + + ThrowJsonPropertyError(expiresOnPropertyName); } diff --git a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp index d8c210bda..6a44b8027 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp @@ -9,6 +9,7 @@ #include +using Azure::DateTime; using Azure::Core::Context; using Azure::Core::Url; using Azure::Core::Credentials::AccessToken; @@ -406,3 +407,259 @@ TEST(TokenCredentialImpl, NullResponse) return token; })); } + +namespace { +std::string MakeTokenResponse( + std::string const& number, + std::string const& expiresInValue, + std::string const& expiresOnValue) +{ + return ("{\"access_token\":\"ACCESSTOKEN" + number + "\"") + + (expiresInValue.empty() ? "" : (",\"expires_in\":" + expiresInValue)) + + (expiresOnValue.empty() ? "" : (",\"expires_on\":" + expiresOnValue)) + "}"; +} +} // namespace + +TEST(TokenCredentialImpl, ExpirationFormats) +{ + auto const actual = CredentialTestHelper::SimulateTokenRequest( + [](auto transport) { + TokenCredentialOptions options; + options.Transport.Transport = transport; + + return std::make_unique( + HttpMethod::Get, Url("https://microsoft.com/"), options); + }, + std::vector(47, {"https://azure.com/.default"}), + std::vector{ + MakeTokenResponse("00", "3600", ""), + MakeTokenResponse("01", "\"3600\"", ""), + MakeTokenResponse("02", "\"unknown format\"", ""), + MakeTokenResponse("03", "\"\"", ""), + MakeTokenResponse("04", "null", ""), + MakeTokenResponse("05", "", "43040261106"), + MakeTokenResponse("06", "", "\"43040261106\""), + MakeTokenResponse("07", "", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("08", "", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("09", "", "\"unknown format\""), + MakeTokenResponse("10", "", "\"\""), + MakeTokenResponse("11", "", "null"), + MakeTokenResponse("12", "3600", "43040261106"), + MakeTokenResponse("13", "3600", "\"43040261106\""), + MakeTokenResponse("14", "3600", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("15", "3600", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("16", "3600", "\"unknown format\""), + MakeTokenResponse("17", "3600", "\"\""), + MakeTokenResponse("18", "3600", "null"), + MakeTokenResponse("19", "\"3600\"", "43040261106"), + MakeTokenResponse("20", "\"3600\"", "\"43040261106\""), + MakeTokenResponse("21", "\"3600\"", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("22", "\"3600\"", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("23", "\"3600\"", "\"unknown format\""), + MakeTokenResponse("24", "\"3600\"", "\"\""), + MakeTokenResponse("25", "\"3600\"", "null"), + MakeTokenResponse("26", "\"unknown format\"", "43040261106"), + MakeTokenResponse("27", "\"unknown format\"", "\"43040261106\""), + MakeTokenResponse("28", "\"unknown format\"", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("29", "\"unknown format\"", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("30", "\"unknown format\"", "\"unknown format\""), + MakeTokenResponse("31", "\"unknown format\"", "\"\""), + MakeTokenResponse("32", "\"unknown format\"", "null"), + MakeTokenResponse("33", "\"\"", "43040261106"), + MakeTokenResponse("34", "\"\"", "\"43040261106\""), + MakeTokenResponse("35", "\"\"", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("36", "\"\"", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("37", "\"\"", "\"unknown format\""), + MakeTokenResponse("38", "\"\"", "\"\""), + MakeTokenResponse("39", "\"\"", "null"), + MakeTokenResponse("40", "null", "43040261106"), + MakeTokenResponse("41", "null", "\"43040261106\""), + MakeTokenResponse("42", "null", "\"3333-11-22T04:05:06.000Z\""), + MakeTokenResponse("43", "null", "\"Sun, 22 Nov 3333 04:05:06 GMT\""), + MakeTokenResponse("44", "null", "\"unknown format\""), + MakeTokenResponse("45", "null", "\"\""), + MakeTokenResponse("46", "null", "null"), + }, + [](auto& credential, auto& tokenRequestContext, auto& context) { + AccessToken token; + token.Token = "FAILED"; + + try + { + token = credential.GetToken(tokenRequestContext, context); + } + catch (AuthenticationException const&) + { + } + + return token; + }); + + EXPECT_EQ(actual.Requests.size(), 47U); + EXPECT_EQ(actual.Responses.size(), 47U); + + auto const& response00 = actual.Responses.at(0); + auto const& response01 = actual.Responses.at(1); + auto const& response02 = actual.Responses.at(2); + auto const& response03 = actual.Responses.at(3); + auto const& response04 = actual.Responses.at(4); + auto const& response05 = actual.Responses.at(5); + auto const& response06 = actual.Responses.at(6); + auto const& response07 = actual.Responses.at(7); + auto const& response08 = actual.Responses.at(8); + auto const& response09 = actual.Responses.at(9); + auto const& response10 = actual.Responses.at(10); + auto const& response11 = actual.Responses.at(11); + auto const& response12 = actual.Responses.at(12); + auto const& response13 = actual.Responses.at(13); + auto const& response14 = actual.Responses.at(14); + auto const& response15 = actual.Responses.at(15); + auto const& response16 = actual.Responses.at(16); + auto const& response17 = actual.Responses.at(17); + auto const& response18 = actual.Responses.at(18); + auto const& response19 = actual.Responses.at(19); + auto const& response20 = actual.Responses.at(20); + auto const& response21 = actual.Responses.at(21); + auto const& response22 = actual.Responses.at(22); + auto const& response23 = actual.Responses.at(23); + auto const& response24 = actual.Responses.at(24); + auto const& response25 = actual.Responses.at(25); + auto const& response26 = actual.Responses.at(26); + auto const& response27 = actual.Responses.at(27); + auto const& response28 = actual.Responses.at(28); + auto const& response29 = actual.Responses.at(29); + auto const& response30 = actual.Responses.at(30); + auto const& response31 = actual.Responses.at(31); + auto const& response32 = actual.Responses.at(32); + auto const& response33 = actual.Responses.at(33); + auto const& response34 = actual.Responses.at(34); + auto const& response35 = actual.Responses.at(35); + auto const& response36 = actual.Responses.at(36); + auto const& response37 = actual.Responses.at(37); + auto const& response38 = actual.Responses.at(38); + auto const& response39 = actual.Responses.at(39); + auto const& response40 = actual.Responses.at(40); + auto const& response41 = actual.Responses.at(41); + auto const& response42 = actual.Responses.at(42); + auto const& response43 = actual.Responses.at(43); + auto const& response44 = actual.Responses.at(44); + auto const& response45 = actual.Responses.at(45); + auto const& response46 = actual.Responses.at(46); + + EXPECT_EQ(response00.AccessToken.Token, "ACCESSTOKEN00"); + EXPECT_EQ(response01.AccessToken.Token, "ACCESSTOKEN01"); + EXPECT_EQ(response02.AccessToken.Token, "FAILED"); + EXPECT_EQ(response03.AccessToken.Token, "FAILED"); + EXPECT_EQ(response04.AccessToken.Token, "FAILED"); + EXPECT_EQ(response05.AccessToken.Token, "ACCESSTOKEN05"); + EXPECT_EQ(response06.AccessToken.Token, "ACCESSTOKEN06"); + EXPECT_EQ(response07.AccessToken.Token, "ACCESSTOKEN07"); + EXPECT_EQ(response08.AccessToken.Token, "ACCESSTOKEN08"); + EXPECT_EQ(response09.AccessToken.Token, "FAILED"); + EXPECT_EQ(response10.AccessToken.Token, "FAILED"); + EXPECT_EQ(response11.AccessToken.Token, "FAILED"); + EXPECT_EQ(response12.AccessToken.Token, "ACCESSTOKEN12"); + EXPECT_EQ(response13.AccessToken.Token, "ACCESSTOKEN13"); + EXPECT_EQ(response14.AccessToken.Token, "ACCESSTOKEN14"); + EXPECT_EQ(response15.AccessToken.Token, "ACCESSTOKEN15"); + EXPECT_EQ(response16.AccessToken.Token, "ACCESSTOKEN16"); + EXPECT_EQ(response17.AccessToken.Token, "ACCESSTOKEN17"); + EXPECT_EQ(response18.AccessToken.Token, "ACCESSTOKEN18"); + EXPECT_EQ(response19.AccessToken.Token, "ACCESSTOKEN19"); + EXPECT_EQ(response20.AccessToken.Token, "ACCESSTOKEN20"); + EXPECT_EQ(response21.AccessToken.Token, "ACCESSTOKEN21"); + EXPECT_EQ(response22.AccessToken.Token, "ACCESSTOKEN22"); + EXPECT_EQ(response23.AccessToken.Token, "ACCESSTOKEN23"); + EXPECT_EQ(response24.AccessToken.Token, "ACCESSTOKEN24"); + EXPECT_EQ(response25.AccessToken.Token, "ACCESSTOKEN25"); + EXPECT_EQ(response26.AccessToken.Token, "ACCESSTOKEN26"); + EXPECT_EQ(response27.AccessToken.Token, "ACCESSTOKEN27"); + EXPECT_EQ(response28.AccessToken.Token, "ACCESSTOKEN28"); + EXPECT_EQ(response29.AccessToken.Token, "ACCESSTOKEN29"); + EXPECT_EQ(response30.AccessToken.Token, "FAILED"); + EXPECT_EQ(response31.AccessToken.Token, "FAILED"); + EXPECT_EQ(response32.AccessToken.Token, "FAILED"); + EXPECT_EQ(response33.AccessToken.Token, "ACCESSTOKEN33"); + EXPECT_EQ(response34.AccessToken.Token, "ACCESSTOKEN34"); + EXPECT_EQ(response35.AccessToken.Token, "ACCESSTOKEN35"); + EXPECT_EQ(response36.AccessToken.Token, "ACCESSTOKEN36"); + EXPECT_EQ(response37.AccessToken.Token, "FAILED"); + EXPECT_EQ(response38.AccessToken.Token, "FAILED"); + EXPECT_EQ(response39.AccessToken.Token, "FAILED"); + EXPECT_EQ(response40.AccessToken.Token, "ACCESSTOKEN40"); + EXPECT_EQ(response41.AccessToken.Token, "ACCESSTOKEN41"); + EXPECT_EQ(response42.AccessToken.Token, "ACCESSTOKEN42"); + EXPECT_EQ(response43.AccessToken.Token, "ACCESSTOKEN43"); + EXPECT_EQ(response44.AccessToken.Token, "FAILED"); + EXPECT_EQ(response45.AccessToken.Token, "FAILED"); + EXPECT_EQ(response46.AccessToken.Token, "FAILED"); + + using namespace std::chrono_literals; + EXPECT_GE(response00.AccessToken.ExpiresOn, response00.EarliestExpiration + 3600s); + EXPECT_LE(response00.AccessToken.ExpiresOn, response00.LatestExpiration + 3600s); + + EXPECT_GE(response01.AccessToken.ExpiresOn, response01.EarliestExpiration + 3600s); + EXPECT_LE(response01.AccessToken.ExpiresOn, response01.LatestExpiration + 3600s); + + EXPECT_EQ(response05.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response06.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response07.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response08.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + + EXPECT_GE(response12.AccessToken.ExpiresOn, response12.EarliestExpiration + 3600s); + EXPECT_LE(response12.AccessToken.ExpiresOn, response12.LatestExpiration + 3600s); + + EXPECT_GE(response13.AccessToken.ExpiresOn, response13.EarliestExpiration + 3600s); + EXPECT_LE(response13.AccessToken.ExpiresOn, response13.LatestExpiration + 3600s); + + EXPECT_GE(response14.AccessToken.ExpiresOn, response14.EarliestExpiration + 3600s); + EXPECT_LE(response14.AccessToken.ExpiresOn, response14.LatestExpiration + 3600s); + + EXPECT_GE(response15.AccessToken.ExpiresOn, response15.EarliestExpiration + 3600s); + EXPECT_LE(response15.AccessToken.ExpiresOn, response15.LatestExpiration + 3600s); + + EXPECT_GE(response16.AccessToken.ExpiresOn, response16.EarliestExpiration + 3600s); + EXPECT_LE(response16.AccessToken.ExpiresOn, response16.LatestExpiration + 3600s); + + EXPECT_GE(response17.AccessToken.ExpiresOn, response17.EarliestExpiration + 3600s); + EXPECT_LE(response17.AccessToken.ExpiresOn, response17.LatestExpiration + 3600s); + + EXPECT_GE(response18.AccessToken.ExpiresOn, response18.EarliestExpiration + 3600s); + EXPECT_LE(response18.AccessToken.ExpiresOn, response18.LatestExpiration + 3600s); + + EXPECT_GE(response19.AccessToken.ExpiresOn, response19.EarliestExpiration + 3600s); + EXPECT_LE(response19.AccessToken.ExpiresOn, response19.LatestExpiration + 3600s); + + EXPECT_GE(response20.AccessToken.ExpiresOn, response20.EarliestExpiration + 3600s); + EXPECT_LE(response20.AccessToken.ExpiresOn, response20.LatestExpiration + 3600s); + + EXPECT_GE(response21.AccessToken.ExpiresOn, response21.EarliestExpiration + 3600s); + EXPECT_LE(response21.AccessToken.ExpiresOn, response21.LatestExpiration + 3600s); + + EXPECT_GE(response22.AccessToken.ExpiresOn, response22.EarliestExpiration + 3600s); + EXPECT_LE(response22.AccessToken.ExpiresOn, response22.LatestExpiration + 3600s); + + EXPECT_GE(response23.AccessToken.ExpiresOn, response23.EarliestExpiration + 3600s); + EXPECT_LE(response23.AccessToken.ExpiresOn, response23.LatestExpiration + 3600s); + + EXPECT_GE(response24.AccessToken.ExpiresOn, response24.EarliestExpiration + 3600s); + EXPECT_LE(response24.AccessToken.ExpiresOn, response24.LatestExpiration + 3600s); + + EXPECT_GE(response25.AccessToken.ExpiresOn, response25.EarliestExpiration + 3600s); + EXPECT_LE(response25.AccessToken.ExpiresOn, response25.LatestExpiration + 3600s); + + EXPECT_EQ(response26.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response27.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response28.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response29.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + + EXPECT_EQ(response33.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response34.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response35.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response36.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + + EXPECT_EQ(response40.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response41.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response42.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); + EXPECT_EQ(response43.AccessToken.ExpiresOn, DateTime(3333, 11, 22, 4, 5, 6)); +}