From 84560cb5f77be170596b38e629096d589a3570f1 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 3 May 2021 17:11:53 -0700 Subject: [PATCH] Md5 enhance impl and test (#2157) * re-design hash md5 header and implementation --- .../inc/azure/core/cryptography/hash.hpp | 3 +- sdk/core/azure-core/src/cryptography/md5.cpp | 269 ++++++++++-------- sdk/core/azure-core/test/ut/md5.cpp | 33 ++- 3 files changed, 185 insertions(+), 120 deletions(-) diff --git a/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp b/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp index 8cac5a113..4888120dc 100644 --- a/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp +++ b/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp @@ -9,6 +9,7 @@ #pragma once +#include #include #include #include @@ -135,7 +136,7 @@ namespace Azure { namespace Core { namespace Cryptography { ~Md5Hash() override; private: - void* m_md5Context; + std::unique_ptr m_implementation; /** * @brief Computes the hash value of the specified binary input data, including any previously diff --git a/sdk/core/azure-core/src/cryptography/md5.cpp b/sdk/core/azure-core/src/cryptography/md5.cpp index 62e0e5739..8971a6864 100644 --- a/sdk/core/azure-core/src/cryptography/md5.cpp +++ b/sdk/core/azure-core/src/cryptography/md5.cpp @@ -16,156 +16,189 @@ #include #include -namespace Azure { namespace Core { namespace Cryptography { +namespace { #if defined(AZ_PLATFORM_WINDOWS) - namespace _detail { - struct AlgorithmProviderInstance - { - BCRYPT_ALG_HANDLE Handle; - std::size_t ContextSize; - std::size_t HashLength; +class Md5AlgorithmProvider { +private: + // Make sure the initial state of status is non-successful + NTSTATUS m_status = 0; - AlgorithmProviderInstance() - { - NTSTATUS status = BCryptOpenAlgorithmProvider(&Handle, BCRYPT_MD5_ALGORITHM, nullptr, 0); - if (!BCRYPT_SUCCESS(status)) - { - throw std::runtime_error("BCryptOpenAlgorithmProvider failed"); - } - DWORD objectLength = 0; - DWORD dataLength = 0; - status = BCryptGetProperty( - Handle, - BCRYPT_OBJECT_LENGTH, - reinterpret_cast(&objectLength), - sizeof(objectLength), - &dataLength, - 0); - if (!BCRYPT_SUCCESS(status)) - { - throw std::runtime_error("BCryptGetProperty failed"); - } - ContextSize = objectLength; - DWORD hashLength = 0; - status = BCryptGetProperty( - Handle, - BCRYPT_HASH_LENGTH, - reinterpret_cast(&hashLength), - sizeof(hashLength), - &dataLength, - 0); - if (!BCRYPT_SUCCESS(status)) - { - throw std::runtime_error("BCryptGetProperty failed"); - } - HashLength = hashLength; - } +public: + BCRYPT_ALG_HANDLE Handle; + std::size_t ContextSize; + std::size_t HashLength; - ~AlgorithmProviderInstance() { BCryptCloseAlgorithmProvider(Handle, 0); } - }; - - struct Md5HashContext - { - std::string buffer; - BCRYPT_HASH_HANDLE hashHandle = nullptr; - std::size_t hashLength = 0; - }; - } // namespace _detail - - Md5Hash::Md5Hash() + Md5AlgorithmProvider() { - static _detail::AlgorithmProviderInstance AlgorithmProvider{}; - - _detail::Md5HashContext* md5Context = new _detail::Md5HashContext; - m_md5Context = md5Context; - md5Context->buffer.resize(AlgorithmProvider.ContextSize); - md5Context->hashLength = AlgorithmProvider.HashLength; - - NTSTATUS status = BCryptCreateHash( - AlgorithmProvider.Handle, - &md5Context->hashHandle, - reinterpret_cast(&md5Context->buffer[0]), - static_cast(md5Context->buffer.size()), - nullptr, - 0, - 0); - if (!BCRYPT_SUCCESS(status)) + // open an algorithm handle + if (!BCRYPT_SUCCESS( + m_status = BCryptOpenAlgorithmProvider(&Handle, BCRYPT_MD5_ALGORITHM, nullptr, 0))) { - throw std::runtime_error("BCryptCreateHash failed"); + throw std::runtime_error("BCryptOpenAlgorithmProvider failed with code: " + m_status); + } + + // calculate the size of the buffer to hold the hash object + DWORD objectLength = 0; + DWORD dataLength = 0; + if (!BCRYPT_SUCCESS( + m_status = BCryptGetProperty( + Handle, + BCRYPT_OBJECT_LENGTH, + reinterpret_cast(&objectLength), + sizeof(objectLength), + &dataLength, + 0))) + { + throw std::runtime_error("BCryptGetProperty failed with code: " + m_status); + } + + // calculate the length of the hash + ContextSize = objectLength; + DWORD hashLength = 0; + if (!BCRYPT_SUCCESS( + m_status = BCryptGetProperty( + Handle, + BCRYPT_HASH_LENGTH, + reinterpret_cast(&hashLength), + sizeof(hashLength), + &dataLength, + 0))) + { + throw std::runtime_error("BCryptGetProperty failed with code: " + m_status); + } + HashLength = hashLength; + } + + ~Md5AlgorithmProvider() + { + if (Handle) + { + BCryptCloseAlgorithmProvider(Handle, 0); + } + } +}; + +Md5AlgorithmProvider const& GetMD5AlgorithmProvider() +{ + static Md5AlgorithmProvider instance; + return instance; +} + +class Md5BCrypt : public Azure::Core::Cryptography::Hash { +private: + // Make sure the initial state of status is non-successful + NTSTATUS m_status = 0; + BCRYPT_HASH_HANDLE m_hashHandle = nullptr; + std::size_t m_hashLength = 0; + std::string m_buffer; + + void OnAppend(const uint8_t* data, std::size_t length) + { + if (!BCRYPT_SUCCESS( + m_status = BCryptHashData( + m_hashHandle, + reinterpret_cast(const_cast(data)), + static_cast(length), + 0))) + { + throw std::runtime_error("BCryptHashData failed with code: " + m_status); } } - Md5Hash::~Md5Hash() - { - _detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context); - BCryptDestroyHash(md5Context->hashHandle); - delete md5Context; - } - - void Md5Hash::OnAppend(const uint8_t* data, std::size_t length) - { - _detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context); - - NTSTATUS status = BCryptHashData( - md5Context->hashHandle, - reinterpret_cast(const_cast(data)), - static_cast(length), - 0); - if (!BCRYPT_SUCCESS(status)) - { - throw std::runtime_error("BCryptHashData failed"); - } - } - - std::vector Md5Hash::OnFinal(const uint8_t* data, std::size_t length) + std::vector OnFinal(const uint8_t* data, std::size_t length) { OnAppend(data, length); - _detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context); + std::vector hash; - hash.resize(md5Context->hashLength); - NTSTATUS status = BCryptFinishHash( - md5Context->hashHandle, - reinterpret_cast(&hash[0]), - static_cast(hash.size()), - 0); - if (!BCRYPT_SUCCESS(status)) + hash.resize(m_hashLength); + if (!BCRYPT_SUCCESS( + m_status = BCryptFinishHash( + m_hashHandle, + reinterpret_cast(&hash[0]), + static_cast(hash.size()), + 0))) { - throw std::runtime_error("BCryptFinishHash failed"); + throw std::runtime_error("BCryptFinishHash failed with code: " + m_status); } return hash; } +public: + Md5BCrypt() + { + m_buffer.resize(GetMD5AlgorithmProvider().ContextSize); + m_hashLength = GetMD5AlgorithmProvider().HashLength; + + if (!BCRYPT_SUCCESS( + m_status = BCryptCreateHash( + GetMD5AlgorithmProvider().Handle, + &m_hashHandle, + reinterpret_cast(&m_buffer[0]), + static_cast(m_buffer.size()), + nullptr, + 0, + 0))) + { + throw std::runtime_error("BCryptCreateHash failed with code: " + m_status); + } + } + + ~Md5BCrypt() + { + if (m_hashHandle) + { + BCryptDestroyHash(m_hashHandle); + } + } +}; + +} // namespace +Azure::Core::Cryptography::Md5Hash::Md5Hash() : m_implementation(std::make_unique()) {} + #elif defined(AZ_PLATFORM_POSIX) - Md5Hash::Md5Hash() +class Md5OpenSSL : public Azure::Core::Cryptography::Hash { +private: + std::unique_ptr m_context; + + void OnAppend(const uint8_t* data, std::size_t length) { - MD5_CTX* md5Context = new MD5_CTX; - m_md5Context = md5Context; - MD5_Init(md5Context); + MD5_Update(m_context.get(), data, length); } - Md5Hash::~Md5Hash() + std::vector OnFinal(const uint8_t* data, std::size_t length) { - MD5_CTX* md5Context = static_cast(m_md5Context); - delete md5Context; + OnAppend(data, length); + unsigned char hash[MD5_DIGEST_LENGTH]; + MD5_Final(hash, m_context.get()); + return std::vector(std::begin(hash), std::end(hash)); } +public: + Md5OpenSSL() + { + m_context = std::make_unique(); + MD5_Init(m_context.get()); + } +}; + +} // namespace +Azure::Core::Cryptography::Md5Hash::Md5Hash() : m_implementation(std::make_unique()) {} +#endif + +namespace Azure { namespace Core { namespace Cryptography { + Md5Hash::~Md5Hash() {} + void Md5Hash::OnAppend(const uint8_t* data, std::size_t length) { - MD5_CTX* md5Context = static_cast(m_md5Context); - MD5_Update(md5Context, data, length); + m_implementation->Append(data, length); } std::vector Md5Hash::OnFinal(const uint8_t* data, std::size_t length) { - OnAppend(data, length); - MD5_CTX* md5Context = static_cast(m_md5Context); - unsigned char hash[MD5_DIGEST_LENGTH]; - MD5_Final(hash, md5Context); - return std::vector(std::begin(hash), std::end(hash)); + return m_implementation->Final(data, length); } -#endif }}} // namespace Azure::Core::Cryptography diff --git a/sdk/core/azure-core/test/ut/md5.cpp b/sdk/core/azure-core/test/ut/md5.cpp index 523aa8cf7..3135fe199 100644 --- a/sdk/core/azure-core/test/ut/md5.cpp +++ b/sdk/core/azure-core/test/ut/md5.cpp @@ -4,9 +4,11 @@ #include #include #include +#include #include #include #include +#include #include using namespace Azure::Core::Cryptography; @@ -98,7 +100,7 @@ TEST(Md5Hash, Basic) TEST(Md5Hash, ExpectThrow) { std::string data = ""; - const uint8_t* ptr = reinterpret_cast(data.data()); + const uint8_t* ptr = reinterpret_cast(data.c_str()); Md5Hash instance; EXPECT_THROW(instance.Final(nullptr, 1), std::invalid_argument); @@ -118,3 +120,32 @@ TEST(Md5Hash, CtorDtor) Md5Hash instance; } } + +TEST(Md5Hash, multiThread) +{ + auto hashThreadRoutine = [](int sleepFor) { + Md5Hash instance; + std::string data = ""; + const uint8_t* ptr = reinterpret_cast(data.c_str()); + + std::this_thread::sleep_for(std::chrono::milliseconds(sleepFor)); + + EXPECT_EQ( + Azure::Core::Convert::Base64Encode(instance.Final(ptr, data.length())), + "1B2M2Y8AsgTpgAmY7PhCfg=="); + }; + + constexpr static int size = 100; + std::vector pool; + + // Make 100 threads run after a little sleep + // Each created thread will wait from 0 to 3 milliseconds to start to make threads overlap + for (int counter = 0; counter < size; counter++) + { + pool.emplace_back(std::thread(hashThreadRoutine, counter % 4)); + } + for (int counter = 0; counter < size; counter++) + { + pool[counter].join(); + } +}