Md5 enhance impl and test (#2157)
* re-design hash md5 header and implementation
This commit is contained in:
parent
8cf7746b6b
commit
84560cb5f7
@ -9,6 +9,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
@ -135,7 +136,7 @@ namespace Azure { namespace Core { namespace Cryptography {
|
||||
~Md5Hash() override;
|
||||
|
||||
private:
|
||||
void* m_md5Context;
|
||||
std::unique_ptr<Hash> m_implementation;
|
||||
|
||||
/**
|
||||
* @brief Computes the hash value of the specified binary input data, including any previously
|
||||
|
||||
@ -16,156 +16,189 @@
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
namespace Azure { namespace Core { namespace Cryptography {
|
||||
namespace {
|
||||
|
||||
#if defined(AZ_PLATFORM_WINDOWS)
|
||||
|
||||
namespace _detail {
|
||||
struct AlgorithmProviderInstance
|
||||
{
|
||||
BCRYPT_ALG_HANDLE Handle;
|
||||
std::size_t ContextSize;
|
||||
std::size_t HashLength;
|
||||
class Md5AlgorithmProvider {
|
||||
private:
|
||||
// Make sure the initial state of status is non-successful
|
||||
NTSTATUS m_status = 0;
|
||||
|
||||
AlgorithmProviderInstance()
|
||||
{
|
||||
NTSTATUS status = BCryptOpenAlgorithmProvider(&Handle, BCRYPT_MD5_ALGORITHM, nullptr, 0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptOpenAlgorithmProvider failed");
|
||||
}
|
||||
DWORD objectLength = 0;
|
||||
DWORD dataLength = 0;
|
||||
status = BCryptGetProperty(
|
||||
Handle,
|
||||
BCRYPT_OBJECT_LENGTH,
|
||||
reinterpret_cast<PBYTE>(&objectLength),
|
||||
sizeof(objectLength),
|
||||
&dataLength,
|
||||
0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptGetProperty failed");
|
||||
}
|
||||
ContextSize = objectLength;
|
||||
DWORD hashLength = 0;
|
||||
status = BCryptGetProperty(
|
||||
Handle,
|
||||
BCRYPT_HASH_LENGTH,
|
||||
reinterpret_cast<PBYTE>(&hashLength),
|
||||
sizeof(hashLength),
|
||||
&dataLength,
|
||||
0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptGetProperty failed");
|
||||
}
|
||||
HashLength = hashLength;
|
||||
}
|
||||
public:
|
||||
BCRYPT_ALG_HANDLE Handle;
|
||||
std::size_t ContextSize;
|
||||
std::size_t HashLength;
|
||||
|
||||
~AlgorithmProviderInstance() { BCryptCloseAlgorithmProvider(Handle, 0); }
|
||||
};
|
||||
|
||||
struct Md5HashContext
|
||||
{
|
||||
std::string buffer;
|
||||
BCRYPT_HASH_HANDLE hashHandle = nullptr;
|
||||
std::size_t hashLength = 0;
|
||||
};
|
||||
} // namespace _detail
|
||||
|
||||
Md5Hash::Md5Hash()
|
||||
Md5AlgorithmProvider()
|
||||
{
|
||||
static _detail::AlgorithmProviderInstance AlgorithmProvider{};
|
||||
|
||||
_detail::Md5HashContext* md5Context = new _detail::Md5HashContext;
|
||||
m_md5Context = md5Context;
|
||||
md5Context->buffer.resize(AlgorithmProvider.ContextSize);
|
||||
md5Context->hashLength = AlgorithmProvider.HashLength;
|
||||
|
||||
NTSTATUS status = BCryptCreateHash(
|
||||
AlgorithmProvider.Handle,
|
||||
&md5Context->hashHandle,
|
||||
reinterpret_cast<PUCHAR>(&md5Context->buffer[0]),
|
||||
static_cast<ULONG>(md5Context->buffer.size()),
|
||||
nullptr,
|
||||
0,
|
||||
0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
// open an algorithm handle
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptOpenAlgorithmProvider(&Handle, BCRYPT_MD5_ALGORITHM, nullptr, 0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptCreateHash failed");
|
||||
throw std::runtime_error("BCryptOpenAlgorithmProvider failed with code: " + m_status);
|
||||
}
|
||||
|
||||
// calculate the size of the buffer to hold the hash object
|
||||
DWORD objectLength = 0;
|
||||
DWORD dataLength = 0;
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptGetProperty(
|
||||
Handle,
|
||||
BCRYPT_OBJECT_LENGTH,
|
||||
reinterpret_cast<PBYTE>(&objectLength),
|
||||
sizeof(objectLength),
|
||||
&dataLength,
|
||||
0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptGetProperty failed with code: " + m_status);
|
||||
}
|
||||
|
||||
// calculate the length of the hash
|
||||
ContextSize = objectLength;
|
||||
DWORD hashLength = 0;
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptGetProperty(
|
||||
Handle,
|
||||
BCRYPT_HASH_LENGTH,
|
||||
reinterpret_cast<PBYTE>(&hashLength),
|
||||
sizeof(hashLength),
|
||||
&dataLength,
|
||||
0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptGetProperty failed with code: " + m_status);
|
||||
}
|
||||
HashLength = hashLength;
|
||||
}
|
||||
|
||||
~Md5AlgorithmProvider()
|
||||
{
|
||||
if (Handle)
|
||||
{
|
||||
BCryptCloseAlgorithmProvider(Handle, 0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Md5AlgorithmProvider const& GetMD5AlgorithmProvider()
|
||||
{
|
||||
static Md5AlgorithmProvider instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
class Md5BCrypt : public Azure::Core::Cryptography::Hash {
|
||||
private:
|
||||
// Make sure the initial state of status is non-successful
|
||||
NTSTATUS m_status = 0;
|
||||
BCRYPT_HASH_HANDLE m_hashHandle = nullptr;
|
||||
std::size_t m_hashLength = 0;
|
||||
std::string m_buffer;
|
||||
|
||||
void OnAppend(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptHashData(
|
||||
m_hashHandle,
|
||||
reinterpret_cast<PBYTE>(const_cast<uint8_t*>(data)),
|
||||
static_cast<ULONG>(length),
|
||||
0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptHashData failed with code: " + m_status);
|
||||
}
|
||||
}
|
||||
|
||||
Md5Hash::~Md5Hash()
|
||||
{
|
||||
_detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context);
|
||||
BCryptDestroyHash(md5Context->hashHandle);
|
||||
delete md5Context;
|
||||
}
|
||||
|
||||
void Md5Hash::OnAppend(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
_detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context);
|
||||
|
||||
NTSTATUS status = BCryptHashData(
|
||||
md5Context->hashHandle,
|
||||
reinterpret_cast<PBYTE>(const_cast<uint8_t*>(data)),
|
||||
static_cast<ULONG>(length),
|
||||
0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptHashData failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> Md5Hash::OnFinal(const uint8_t* data, std::size_t length)
|
||||
std::vector<uint8_t> OnFinal(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
OnAppend(data, length);
|
||||
_detail::Md5HashContext* md5Context = static_cast<_detail::Md5HashContext*>(m_md5Context);
|
||||
|
||||
std::vector<uint8_t> hash;
|
||||
hash.resize(md5Context->hashLength);
|
||||
NTSTATUS status = BCryptFinishHash(
|
||||
md5Context->hashHandle,
|
||||
reinterpret_cast<PUCHAR>(&hash[0]),
|
||||
static_cast<ULONG>(hash.size()),
|
||||
0);
|
||||
if (!BCRYPT_SUCCESS(status))
|
||||
hash.resize(m_hashLength);
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptFinishHash(
|
||||
m_hashHandle,
|
||||
reinterpret_cast<PUCHAR>(&hash[0]),
|
||||
static_cast<ULONG>(hash.size()),
|
||||
0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptFinishHash failed");
|
||||
throw std::runtime_error("BCryptFinishHash failed with code: " + m_status);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
public:
|
||||
Md5BCrypt()
|
||||
{
|
||||
m_buffer.resize(GetMD5AlgorithmProvider().ContextSize);
|
||||
m_hashLength = GetMD5AlgorithmProvider().HashLength;
|
||||
|
||||
if (!BCRYPT_SUCCESS(
|
||||
m_status = BCryptCreateHash(
|
||||
GetMD5AlgorithmProvider().Handle,
|
||||
&m_hashHandle,
|
||||
reinterpret_cast<PUCHAR>(&m_buffer[0]),
|
||||
static_cast<ULONG>(m_buffer.size()),
|
||||
nullptr,
|
||||
0,
|
||||
0)))
|
||||
{
|
||||
throw std::runtime_error("BCryptCreateHash failed with code: " + m_status);
|
||||
}
|
||||
}
|
||||
|
||||
~Md5BCrypt()
|
||||
{
|
||||
if (m_hashHandle)
|
||||
{
|
||||
BCryptDestroyHash(m_hashHandle);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
Azure::Core::Cryptography::Md5Hash::Md5Hash() : m_implementation(std::make_unique<Md5BCrypt>()) {}
|
||||
|
||||
#elif defined(AZ_PLATFORM_POSIX)
|
||||
|
||||
Md5Hash::Md5Hash()
|
||||
class Md5OpenSSL : public Azure::Core::Cryptography::Hash {
|
||||
private:
|
||||
std::unique_ptr<MD5_CTX> m_context;
|
||||
|
||||
void OnAppend(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
MD5_CTX* md5Context = new MD5_CTX;
|
||||
m_md5Context = md5Context;
|
||||
MD5_Init(md5Context);
|
||||
MD5_Update(m_context.get(), data, length);
|
||||
}
|
||||
|
||||
Md5Hash::~Md5Hash()
|
||||
std::vector<uint8_t> OnFinal(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
MD5_CTX* md5Context = static_cast<MD5_CTX*>(m_md5Context);
|
||||
delete md5Context;
|
||||
OnAppend(data, length);
|
||||
unsigned char hash[MD5_DIGEST_LENGTH];
|
||||
MD5_Final(hash, m_context.get());
|
||||
return std::vector<uint8_t>(std::begin(hash), std::end(hash));
|
||||
}
|
||||
|
||||
public:
|
||||
Md5OpenSSL()
|
||||
{
|
||||
m_context = std::make_unique<MD5_CTX>();
|
||||
MD5_Init(m_context.get());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
Azure::Core::Cryptography::Md5Hash::Md5Hash() : m_implementation(std::make_unique<Md5OpenSSL>()) {}
|
||||
#endif
|
||||
|
||||
namespace Azure { namespace Core { namespace Cryptography {
|
||||
Md5Hash::~Md5Hash() {}
|
||||
|
||||
void Md5Hash::OnAppend(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
MD5_CTX* md5Context = static_cast<MD5_CTX*>(m_md5Context);
|
||||
MD5_Update(md5Context, data, length);
|
||||
m_implementation->Append(data, length);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> Md5Hash::OnFinal(const uint8_t* data, std::size_t length)
|
||||
{
|
||||
OnAppend(data, length);
|
||||
MD5_CTX* md5Context = static_cast<MD5_CTX*>(m_md5Context);
|
||||
unsigned char hash[MD5_DIGEST_LENGTH];
|
||||
MD5_Final(hash, md5Context);
|
||||
return std::vector<uint8_t>(std::begin(hash), std::end(hash));
|
||||
return m_implementation->Final(data, length);
|
||||
}
|
||||
|
||||
#endif
|
||||
}}} // namespace Azure::Core::Cryptography
|
||||
|
||||
@ -4,9 +4,11 @@
|
||||
#include <algorithm>
|
||||
#include <azure/core/base64.hpp>
|
||||
#include <azure/core/cryptography/hash.hpp>
|
||||
#include <chrono>
|
||||
#include <gtest/gtest.h>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
using namespace Azure::Core::Cryptography;
|
||||
@ -98,7 +100,7 @@ TEST(Md5Hash, Basic)
|
||||
TEST(Md5Hash, ExpectThrow)
|
||||
{
|
||||
std::string data = "";
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(data.data());
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(data.c_str());
|
||||
Md5Hash instance;
|
||||
|
||||
EXPECT_THROW(instance.Final(nullptr, 1), std::invalid_argument);
|
||||
@ -118,3 +120,32 @@ TEST(Md5Hash, CtorDtor)
|
||||
Md5Hash instance;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Md5Hash, multiThread)
|
||||
{
|
||||
auto hashThreadRoutine = [](int sleepFor) {
|
||||
Md5Hash instance;
|
||||
std::string data = "";
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(data.c_str());
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(sleepFor));
|
||||
|
||||
EXPECT_EQ(
|
||||
Azure::Core::Convert::Base64Encode(instance.Final(ptr, data.length())),
|
||||
"1B2M2Y8AsgTpgAmY7PhCfg==");
|
||||
};
|
||||
|
||||
constexpr static int size = 100;
|
||||
std::vector<std::thread> pool;
|
||||
|
||||
// Make 100 threads run after a little sleep
|
||||
// Each created thread will wait from 0 to 3 milliseconds to start to make threads overlap
|
||||
for (int counter = 0; counter < size; counter++)
|
||||
{
|
||||
pool.emplace_back(std::thread(hashThreadRoutine, counter % 4));
|
||||
}
|
||||
for (int counter = 0; counter < size; counter++)
|
||||
{
|
||||
pool[counter].join();
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user