From 9770fb77dc77a24cd08942a3711b44d4c9c6b249 Mon Sep 17 00:00:00 2001 From: Larry Osterman Date: Wed, 6 Nov 2024 12:03:08 -0800 Subject: [PATCH] Implement mTLS support in WinHTTP transport. (#6131) * Very preliminary mTLS implementation * Tests for TLS client certificate * Tested mTLS functionality * Added changelog entry; updated PCCERT_CONTEXT using declaration to be more succinct. --- .vscode/cspell.json | 2 + sdk/core/azure-core/CHANGELOG.md | 3 + .../azure/core/http/win_http_transport.hpp | 50 +- .../private/win_http_transport_impl.hpp | 74 + .../src/http/winhttp/win_http_request.hpp | 5 +- .../src/http/winhttp/win_http_transport.cpp | 1848 ++++++++++------- .../test/ut/transport_adapter_base_test.cpp | 12 +- 7 files changed, 1224 insertions(+), 770 deletions(-) create mode 100644 sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp diff --git a/.vscode/cspell.json b/.vscode/cspell.json index a5a4f8f93..9d2845cc1 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -207,6 +207,8 @@ "opentelemetry", "Osterman", "otel", + "PCCERT", + "PCERT", "PBYTE", "pdbs", "phoebusm", diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 2e75653a1..54f66c6d0 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +- Added mTLS support to WinHTTP transport. + - To enable mTLS, first create an appropriate Windows `PCCERT_CONTEXT` object and set the `TlsClientCertificate` field in `WinHttpTransportOptions` to that certificate before creating the `WinHttpTransport` object. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp index 8f2e3e75f..d131ef01f 100644 --- a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +// cspell: words PCCERT + /** * @file * @brief #Azure::Core::Http::HttpTransport implementation via WinHTTP. @@ -20,10 +22,16 @@ #include #include +/** + * @brief Declaration of a Windows PCCERT_CONTEXT structure from the Windows SDK. + */ +using PCCERT_CONTEXT = const struct _CERT_CONTEXT*; + namespace Azure { namespace Core { namespace Http { namespace _detail { + class WinHttpTransportImpl; class WinHttpRequest; - } + } // namespace _detail /** * @brief Sets the WinHTTP session and connection options used to customize the behavior of the @@ -91,6 +99,12 @@ namespace Azure { namespace Core { namespace Http { * the server. */ std::vector ExpectedTlsRootCertificates; + + /** + * @brief TLS Client Certificate Context, used when the TLS Server requests mTLS client + * authentication. + */ + PCCERT_CONTEXT TlsClientCertificate{nullptr}; }; /** @@ -99,23 +113,16 @@ namespace Azure { namespace Core { namespace Http { */ class WinHttpTransport : public HttpTransport { private: - WinHttpTransportOptions m_options; - // m_sessionhandle is const to ensure immutability. - const Azure::Core::_internal::UniqueHandle m_sessionHandle; + std::unique_ptr<_detail::WinHttpTransportImpl> m_impl; - Azure::Core::_internal::UniqueHandle CreateSessionHandle(); - Azure::Core::_internal::UniqueHandle CreateConnectionHandle( - Azure::Core::Url const& url, - Azure::Core::Context const& context); - - std::unique_ptr<_detail::WinHttpRequest> CreateRequestHandle( - Azure::Core::_internal::UniqueHandle const& connectionHandle, - Azure::Core::Url const& url, - Azure::Core::Http::HttpMethod const& method); - - // Callback to allow a derived transport to extract the request handle. Used for WebSocket - // transports. - virtual void OnUpgradedConnection(std::unique_ptr<_detail::WinHttpRequest> const&){}; + protected: + /** @brief Callback to allow a derived transport to extract the request handle. Used for + * WebSocket transports. + * + * @param request - Request which contains the WinHttp request handle. + */ + virtual void OnUpgradedConnection( + std::unique_ptr<_detail::WinHttpRequest> const& request) const; public: /** @@ -125,11 +132,6 @@ namespace Azure { namespace Core { namespace Http { */ WinHttpTransport(WinHttpTransportOptions const& options = WinHttpTransportOptions()); - /** - * @brief Constructs `%WinHttpTransport`. - * - * @param options Optional parameter to override the default settings. - */ /** * @brief Constructs `%WinHttpTransport` object based on common Azure HTTP Transport Options * @@ -148,6 +150,10 @@ namespace Azure { namespace Core { namespace Http { // and virtual or protected and // non-virtual"](http://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c35-a-base-class-destructor-should-be-either-public-and-virtual-or-protected-and-non-virtual) virtual ~WinHttpTransport(); + + // @cond + friend _detail::WinHttpTransportImpl; + // @endcond }; }}} // namespace Azure::Core::Http diff --git a/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp b/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp new file mode 100644 index 000000000..e0da5951c --- /dev/null +++ b/sdk/core/azure-core/src/http/winhttp/private/win_http_transport_impl.hpp @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include + +namespace Azure { namespace Core { namespace Http { namespace _detail { + class WinHttpTransportImpl : HttpTransport { + private: + WinHttpTransport const* m_parent; + + WinHttpTransportOptions m_options; + // m_sessionhandle is const to ensure immutability. + const Azure::Core::_internal::UniqueHandle m_sessionHandle; + wil::unique_cert_context m_tlsClientCertificate; + + Azure::Core::_internal::UniqueHandle CreateSessionHandle(); + Azure::Core::_internal::UniqueHandle CreateConnectionHandle( + Azure::Core::Url const& url, + Azure::Core::Context const& context); + + std::unique_ptr<_detail::WinHttpRequest> CreateRequestHandle( + Azure::Core::_internal::UniqueHandle const& connectionHandle, + Azure::Core::Url const& url, + Azure::Core::Http::HttpMethod const& method); + + // Callback to allow a derived transport to extract the request handle. Used for WebSocket + // transports. + virtual void OnUpgradedConnection(std::unique_ptr<_detail::WinHttpRequest> const& request) + { + m_parent->OnUpgradedConnection(request); + }; + + public: + /** + * @brief Constructs `%WinHttpTransport`. + * + * @param options Optional parameter to override the default settings. + */ + WinHttpTransportImpl( + WinHttpTransport const* parent, + WinHttpTransportOptions const& options = WinHttpTransportOptions()); + + /** + * @brief Constructs `%WinHttpTransport`. + * + * @param options Optional parameter to override the default settings. + */ + /** + * @brief Constructs `%WinHttpTransport` object based on common Azure HTTP Transport Options + * + */ + WinHttpTransportImpl( + WinHttpTransport const* parent, + Azure::Core::Http::Policies::TransportOptions const& options); + + /** + * @brief Implements the HTTP transport interface to send an HTTP Request and produce an + * HTTP RawResponse. + * + */ + virtual std::unique_ptr Send(Request& request, Context const& context) override; + + // See also: + // [Core Guidelines C.35: "A base class destructor should be either public + // and virtual or protected and + // non-virtual"](http://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c35-a-base-class-destructor-should-be-either-public-and-virtual-or-protected-and-non-virtual) + virtual ~WinHttpTransportImpl(); + }; +}}}} // namespace Azure::Core::Http::_detail diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp b/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp index e84fb1083..41dd27f12 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_request.hpp @@ -28,9 +28,10 @@ #pragma warning(disable : 6553) #pragma warning(disable : 6387) // An argument in result_macros.h may be '0', for the function // 'GetProcAddress'. +#include + #include #pragma warning(pop) -#include #include namespace Azure { namespace Core { namespace Http { namespace _detail { @@ -157,6 +158,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { Azure::Core::_internal::UniqueHandle m_requestHandle; std::unique_ptr m_httpAction; std::vector m_expectedTlsRootCertificates; + wil::unique_cert_context m_tlsClientCertificate; /* * Adds the specified trusted certificates to the specified certificate store. @@ -176,6 +178,7 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { Azure::Core::_internal::UniqueHandle const& connectionHandle, Azure::Core::Url const& url, Azure::Core::Http::HttpMethod const& method, + PCCERT_CONTEXT tlsClientCertificate, WinHttpTransportOptions const& options); ~WinHttpRequest(); diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp index 89e350f12..342cc3531 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -// cspell:words HCERTIFICATECHAIN PCCERT CCERT HCERTCHAINENGINE HCERTSTORE lpsz REFERER +// cspell:words HCERTIFICATECHAIN PCCERT CCERT HCERTCHAINENGINE HCERTSTORE lpsz REFERER hcryptprov +// cspell: words ncrypt hcryptkey NCRYPT ncrypt hcryptkey #include "azure/core/base64.hpp" #include "azure/core/diagnostics/logger.hpp" @@ -13,6 +14,7 @@ #include "azure/core/http/win_http_transport.hpp" #include "win_http_request.hpp" #endif +#include "private/win_http_transport_impl.hpp" #if !defined(WIN32_LEAN_AND_MEAN) #define WIN32_LEAN_AND_MEAN @@ -33,6 +35,7 @@ // 'GetProcAddress'. #include // definitions for wil::unique_cert_chain_context and other RAII type wrappers for Windows types. #pragma warning(pop) +#include "private/win_http_transport_impl.hpp" #include #include @@ -461,9 +464,48 @@ std::string InternetStatusToString(DWORD internetStatus) APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_SETTINGS_READ_COMPLETE); return rv; } + +std::string InternetStatusInformationToString(DWORD internetStatus) +{ + std::string rv; + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_CERT_REV_FAILED); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CERT); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_CERT_REVOKED); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CA); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_CERT_CN_INVALID); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_CERT_DATE_INVALID); + APPEND_ENUM_STRING(WINHTTP_CALLBACK_STATUS_FLAG_SECURITY_CHANNEL_ERROR); + return rv; +} + #undef APPEND_ENUM_STRING } // namespace +namespace Azure { namespace Core { namespace Http { + + WinHttpTransport::WinHttpTransport(WinHttpTransportOptions const& options) + : m_impl{std::make_unique<_detail::WinHttpTransportImpl>(this, options)} + { + } + + WinHttpTransport::WinHttpTransport(Azure::Core::Http::Policies::TransportOptions const& options) + : m_impl{std::make_unique<_detail::WinHttpTransportImpl>(this, options)} + { + } + + std::unique_ptr WinHttpTransport::Send(Request& request, Context const& context) + { + return m_impl->Send(request, context); + } + + WinHttpTransport::~WinHttpTransport() = default; + + void WinHttpTransport::OnUpgradedConnection(std::unique_ptr<_detail::WinHttpRequest> const&) const + { + } + +}}} // namespace Azure::Core::Http + namespace Azure { namespace Core { namespace Http { namespace _detail { bool WinHttpAction::RegisterWinHttpStatusCallback( @@ -552,9 +594,9 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { { if (m_expectedStatus != WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING) { - // Note that the order of scope_exit and lock is important - this ensures that scope_exit is - // destroyed *after* lock is destroyed, ensuring that the event is not set to the signalled - // state before the lock is released. + // Note that the order of scope_exit and lock is important - this ensures that scope_exit + // is destroyed *after* lock is destroyed, ensuring that the event is not set to the + // signalled state before the lock is released. auto scope_exit{m_actionCompleteEvent.SetEvent_scope_exit()}; std::unique_lock lock(m_actionCompleteMutex); m_stowedErrorInformation = stowedErrorInformation; @@ -667,7 +709,10 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { + InternetStatusToString(internetStatus) + ")"); if (internetStatus == WINHTTP_CALLBACK_STATUS_SECURE_FAILURE) { - Log::Write(Logger::Level::Error, "Security failure. :("); + DWORD securityFlags = *reinterpret_cast(statusInformation); + Log::Stream(Logger::Level::Error) + << "Security failure. :(" << std::hex << securityFlags << ") (" + << InternetStatusInformationToString(securityFlags) << ")"; } else if (internetStatus == WINHTTP_CALLBACK_STATUS_REQUEST_ERROR) { @@ -758,777 +803,1088 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { } } } -}}}} // namespace Azure::Core::Http::_detail -Azure::Core::_internal::UniqueHandle WinHttpTransport::CreateSessionHandle() -{ - // Use WinHttpOpen to obtain a session handle. - // The dwFlags is set to 0 - all WinHTTP functions are performed synchronously. - Azure::Core::_internal::UniqueHandle sessionHandle(WinHttpOpen( - NULL, // Do not use a fallback user-agent string, and only rely on the header within the - // request itself. - // If the customer asks for it, enable use of the system default HTTP proxy. - (m_options.EnableSystemDefaultProxy ? WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY - : WINHTTP_ACCESS_TYPE_NO_PROXY), - WINHTTP_NO_PROXY_NAME, - WINHTTP_NO_PROXY_BYPASS, - WINHTTP_FLAG_ASYNC)); // All requests on this session are performed asynchronously. - - if (!sessionHandle) + Azure::Core::_internal::UniqueHandle WinHttpTransportImpl::CreateSessionHandle() { - // Errors include: - // ERROR_WINHTTP_INTERNAL_ERROR - // ERROR_NOT_ENOUGH_MEMORY - GetErrorAndThrow("Error while getting a session handle."); - } + // Use WinHttpOpen to obtain a session handle. + // The dwFlags is set to 0 - all WinHTTP functions are performed synchronously. + Azure::Core::_internal::UniqueHandle sessionHandle(WinHttpOpen( + NULL, // Do not use a fallback user-agent string, and only rely on the header within the + // request itself. + // If the customer asks for it, enable use of the system default HTTP proxy. + (m_options.EnableSystemDefaultProxy ? WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY + : WINHTTP_ACCESS_TYPE_NO_PROXY), + WINHTTP_NO_PROXY_NAME, + WINHTTP_NO_PROXY_BYPASS, + WINHTTP_FLAG_ASYNC)); // All requests on this session are performed asynchronously. - // These options are only available starting from Windows 10 Version 2004, starting 06/09/2020. - // These are primarily round trip time (RTT) performance optimizations, and hence if they don't - // get set successfully, we shouldn't fail the request and continue as if the options don't exist. - // Therefore, we just ignore the error and move on. + if (!sessionHandle) + { + // Errors include: + // ERROR_WINHTTP_INTERNAL_ERROR + // ERROR_NOT_ENOUGH_MEMORY + GetErrorAndThrow("Error while getting a session handle."); + } - // TCP_FAST_OPEN has a bug when the DNS resolution fails which can result - // in a leak. Until that issue is fixed we've disable this option. + // These options are only available starting from Windows 10 Version 2004, starting + // 06/09/2020. These are primarily round trip time (RTT) performance optimizations, and + // hence if they don't get set successfully, we shouldn't fail the request and continue as + // if the options don't exist. Therefore, we just ignore the error and move on. + + // TCP_FAST_OPEN has a bug when the DNS resolution fails which can result + // in a leak. Until that issue is fixed we've disable this option. #if defined(WINHTTP_OPTION_TCP_FAST_OPEN) && FALSE - BOOL tcp_fast_open = TRUE; - WinHttpSetOption( - sessionHandle.get(), WINHTTP_OPTION_TCP_FAST_OPEN, &tcp_fast_open, sizeof(tcp_fast_open)); + BOOL tcp_fast_open = TRUE; + WinHttpSetOption( + sessionHandle.get(), WINHTTP_OPTION_TCP_FAST_OPEN, &tcp_fast_open, sizeof(tcp_fast_open)); #endif #ifdef WINHTTP_OPTION_TLS_FALSE_START - BOOL tls_false_start = TRUE; - WinHttpSetOption( - sessionHandle.get(), - WINHTTP_OPTION_TLS_FALSE_START, - &tls_false_start, - sizeof(tls_false_start)); + BOOL tls_false_start = TRUE; + WinHttpSetOption( + sessionHandle.get(), + WINHTTP_OPTION_TLS_FALSE_START, + &tls_false_start, + sizeof(tls_false_start)); #endif - // Enforce TLS version 1.2 or 1.3 (if available). - auto tlsOption = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; + // Enforce TLS version 1.2 or 1.3 (if available). + auto tlsOption = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; #if defined(WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3) - tlsOption |= WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3; + tlsOption |= WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3; #endif - if (!WinHttpSetOption( - sessionHandle.get(), WINHTTP_OPTION_SECURE_PROTOCOLS, &tlsOption, sizeof(tlsOption))) - { -#if defined(WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3) - // If TLS 1.3 is not available, try to set TLS 1.2 only. - tlsOption = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; if (!WinHttpSetOption( sessionHandle.get(), WINHTTP_OPTION_SECURE_PROTOCOLS, &tlsOption, sizeof(tlsOption))) { -#endif - GetErrorAndThrow("Error while enforcing TLS version for connection request."); #if defined(WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3) - } -#endif - } - - return sessionHandle; -} - -namespace { -WinHttpTransportOptions WinHttpTransportOptionsFromTransportOptions( - Azure::Core::Http::Policies::TransportOptions const& transportOptions) -{ - WinHttpTransportOptions httpOptions; - if (transportOptions.HttpProxy.HasValue()) - { - // WinHTTP proxy strings are semicolon separated elements, each of which - // has the following format: - // ([=]["://"][":"]) - std::string proxyString; - proxyString = "http=" + transportOptions.HttpProxy.Value(); - proxyString += ";"; - proxyString += "https=" + transportOptions.HttpProxy.Value(); - httpOptions.ProxyInformation = proxyString; - } - httpOptions.ProxyUserName = transportOptions.ProxyUserName; - httpOptions.ProxyPassword = transportOptions.ProxyPassword; - // Note that WinHTTP accepts a set of root certificates, even though transportOptions only - // specifies a single one. - if (!transportOptions.ExpectedTlsRootCertificate.empty()) - { - httpOptions.ExpectedTlsRootCertificates.push_back(transportOptions.ExpectedTlsRootCertificate); - } - if (transportOptions.EnableCertificateRevocationListCheck) - { - httpOptions.EnableCertificateRevocationListCheck = true; - } - // If you specify an expected TLS root certificate, you also need to enable ignoring unknown - // CAs. - if (!transportOptions.ExpectedTlsRootCertificate.empty()) - { - httpOptions.IgnoreUnknownCertificateAuthority = true; - } - - if (transportOptions.DisableTlsCertificateValidation) - { - httpOptions.IgnoreUnknownCertificateAuthority = true; - httpOptions.IgnoreInvalidCertificateCommonName = true; - } - - return httpOptions; -} -} // namespace - -WinHttpTransport::WinHttpTransport(WinHttpTransportOptions const& options) - : m_options(options), m_sessionHandle(CreateSessionHandle()) -{ -} - -WinHttpTransport::WinHttpTransport( - Azure::Core::Http::Policies::TransportOptions const& transportOptions) - : WinHttpTransport(WinHttpTransportOptionsFromTransportOptions(transportOptions)) -{ -} - -WinHttpTransport::~WinHttpTransport() = default; - -Azure::Core::_internal::UniqueHandle WinHttpTransport::CreateConnectionHandle( - Azure::Core::Url const& url, - Azure::Core::Context const& context) -{ - // If port is 0, i.e. INTERNET_DEFAULT_PORT, it uses port 80 for HTTP and port 443 for HTTPS. - uint16_t port = url.GetPort(); - - // Before doing any work, check to make sure that the context hasn't already been cancelled. - context.ThrowIfCancelled(); - - // Specify an HTTP server. - // This function always operates synchronously. - Azure::Core::_internal::UniqueHandle rv(WinHttpConnect( - m_sessionHandle.get(), - StringToWideString(url.GetHost()).c_str(), - port == 0 ? INTERNET_DEFAULT_PORT : port, - 0)); - - if (!rv) - { - // Errors include: - // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE - // ERROR_WINHTTP_INTERNAL_ERROR - // ERROR_WINHTTP_INVALID_URL - // ERROR_WINHTTP_OPERATION_CANCELLED - // ERROR_WINHTTP_UNRECOGNIZED_SCHEME - // ERROR_WINHTTP_SHUTDOWN - // ERROR_NOT_ENOUGH_MEMORY - GetErrorAndThrow("Error while getting a connection handle."); - } - return rv; -} - -void _detail::WinHttpRequest::EnableWebSocketsSupport() -{ -#pragma warning(push) - // warning C6387: _Param_(3) could be '0'. -#pragma warning(disable : 6387) - if (!WinHttpSetOption(m_requestHandle.get(), WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, nullptr, 0)) -#pragma warning(pop) - { - GetErrorAndThrow("Error while Enabling WebSocket upgrade."); - } -} - -_detail::WinHttpRequest::WinHttpRequest( - Azure::Core::_internal::UniqueHandle const& connectionHandle, - Azure::Core::Url const& url, - Azure::Core::Http::HttpMethod const& method, - WinHttpTransportOptions const& options) - : m_expectedTlsRootCertificates(options.ExpectedTlsRootCertificates) -{ - const std::string& path = url.GetRelativeUrl(); - HttpMethod requestMethod = method; - bool const requestSecureHttp( - !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( - url.GetScheme(), HttpScheme) - && !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( - url.GetScheme(), WebSocketScheme)); - - // Create an HTTP request handle. - m_requestHandle.reset(WinHttpOpenRequest( - connectionHandle.get(), - HttpMethodToWideString(requestMethod).c_str(), - path.empty() ? NULL : StringToWideString(path).c_str(), // Name of the target resource of - // the specified HTTP verb - NULL, // Use HTTP/1.1 - WINHTTP_NO_REFERER, - WINHTTP_DEFAULT_ACCEPT_TYPES, // No media types are accepted by the client - requestSecureHttp ? WINHTTP_FLAG_SECURE : 0)); // Uses secure transaction semantics (SSL/TLS) - if (!m_requestHandle) - { - // Errors include: - // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE - // ERROR_WINHTTP_INTERNAL_ERROR - // ERROR_WINHTTP_INVALID_URL - // ERROR_WINHTTP_OPERATION_CANCELLED - // ERROR_WINHTTP_UNRECOGNIZED_SCHEME - // ERROR_NOT_ENOUGH_MEMORY - GetErrorAndThrow("Error while getting a request handle."); - } - - if (requestSecureHttp) - { - // If the service requests TLS client certificates, we want to let the WinHTTP APIs know that - // it's ok to initiate the request without a client certificate. - // - // Note: If/When TLS client certificate support is added to the pipeline, this line may need to - // be revisited. - if (!WinHttpSetOption( - m_requestHandle.get(), - WINHTTP_OPTION_CLIENT_CERT_CONTEXT, - WINHTTP_NO_CLIENT_CERT_CONTEXT, - 0)) - { - GetErrorAndThrow("Error while setting client cert context to ignore."); - } - } - - if (!options.ProxyInformation.empty()) - { - WINHTTP_PROXY_INFO proxyInfo{}; - std::wstring proxyWide{StringToWideString(options.ProxyInformation)}; - proxyInfo.dwAccessType = WINHTTP_ACCESS_TYPE_NAMED_PROXY; - proxyInfo.lpszProxy = const_cast(proxyWide.c_str()); - proxyInfo.lpszProxyBypass = WINHTTP_NO_PROXY_BYPASS; - if (!WinHttpSetOption( - m_requestHandle.get(), WINHTTP_OPTION_PROXY, &proxyInfo, sizeof(proxyInfo))) - { - GetErrorAndThrow("Error while setting Proxy information."); - } - } - if (options.ProxyUserName.HasValue() || options.ProxyPassword.HasValue()) - { - if (!WinHttpSetCredentials( - m_requestHandle.get(), - WINHTTP_AUTH_TARGET_PROXY, - WINHTTP_AUTH_SCHEME_BASIC, - StringToWideString(options.ProxyUserName.Value()).c_str(), - StringToWideString(options.ProxyPassword.Value()).c_str(), - 0)) - { - GetErrorAndThrow("Error while setting Proxy credentials."); - } - } - - if (options.IgnoreUnknownCertificateAuthority || !options.ExpectedTlsRootCertificates.empty()) - { - auto option = SECURITY_FLAG_IGNORE_UNKNOWN_CA; - if (!WinHttpSetOption( - m_requestHandle.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) - { - GetErrorAndThrow("Error while setting ignore unknown server certificate."); - } - } - - if (options.IgnoreInvalidCertificateCommonName) - { - auto option = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; - if (!WinHttpSetOption( - m_requestHandle.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) - { - GetErrorAndThrow("Error while setting ignore invalid certificate common name."); - } - } - - if (options.EnableCertificateRevocationListCheck) - { - DWORD value = WINHTTP_ENABLE_SSL_REVOCATION; - if (!WinHttpSetOption( - m_requestHandle.get(), WINHTTP_OPTION_ENABLE_FEATURE, &value, sizeof(value))) - { - GetErrorAndThrow("Error while enabling CRL validation."); - } - } - - DWORD disableRedirects = WINHTTP_DISABLE_REDIRECTS; - if (!WinHttpSetOption( - m_requestHandle.get(), - WINHTTP_OPTION_DISABLE_FEATURE, - &disableRedirects, - sizeof(disableRedirects))) - { - GetErrorAndThrow("Error while disabling redirects."); - } - - // Set the callback function to be called whenever the state of the request handle changes. - m_httpAction = std::make_unique<_detail::WinHttpAction>(this); - - if (!m_httpAction->RegisterWinHttpStatusCallback(m_requestHandle)) - { - GetErrorAndThrow("Error while setting up the status callback."); - } -} - -/* - * Destructor for WinHTTP request. Closes the request handle. - */ -_detail::WinHttpRequest::~WinHttpRequest() -{ - if (!m_requestHandleClosed) - { - Log::Write( - Logger::Level::Informational, - "WinHttpRequest::~WinHttpRequest. Closing handle synchronously."); - - // Close the outstanding request handle, waiting until the HANDLE_CLOSING status is received. - if (!m_httpAction->WaitForAction( - [this]() { - auto requestHandle = m_requestHandle.release(); - if (!WinHttpCloseHandle(requestHandle)) - { - Log::Write( - Logger::Level::Error, - "Error closing WinHTTP handle: " + GetErrorMessage(GetLastError())); - } - }, - - WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING, - Azure::Core::Context{})) - { - Log::Write(Logger::Level::Error, "Error while closing the request handle."); - } - Log::Write(Logger::Level::Informational, "WinHttpRequest::~WinHttpRequest. Handle closed."); - } -} - -std::unique_ptr<_detail::WinHttpRequest> WinHttpTransport::CreateRequestHandle( - Azure::Core::_internal::UniqueHandle const& connectionHandle, - Azure::Core::Url const& url, - Azure::Core::Http::HttpMethod const& method) -{ - auto request{std::make_unique<_detail::WinHttpRequest>(connectionHandle, url, method, m_options)}; - // If we are supporting WebSockets, then let WinHTTP know that it should - // prepare to upgrade the HttpRequest to a WebSocket. - if (HasWebSocketSupport()) - { - request->EnableWebSocketsSupport(); - } - return request; -} - -// For PUT/POST requests, send additional data using WinHttpWriteData. -void _detail::WinHttpRequest::Upload( - Azure::Core::Http::Request& request, - Azure::Core::Context const& context) -{ - auto streamBody = request.GetBodyStream(); - int64_t streamLength = streamBody->Length(); - - // Consider using `MaximumUploadChunkSize` here, after some perf measurements - size_t uploadChunkSize = DefaultUploadChunkSize; - if (streamLength < MaximumUploadChunkSize) - { - uploadChunkSize = static_cast(streamLength); - } - auto unique_buffer = std::make_unique(uploadChunkSize); - - while (true) - { - size_t rawRequestLen = streamBody->Read(unique_buffer.get(), uploadChunkSize, context); - if (rawRequestLen == 0) - { - break; - } - - DWORD dwBytesWritten = 0; - - if (!m_httpAction->WaitForAction( - [&]() { // Write data to the server. - if (!WinHttpWriteData( - m_requestHandle.get(), - unique_buffer.get(), - static_cast(rawRequestLen), - &dwBytesWritten)) - { - GetErrorAndThrow("Error while uploading/sending data."); - } - }, - WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, - context)) - - { - GetErrorAndThrow("Error sending HTTP request asynchronously", m_httpAction->GetStowedError()); - } - } -} - -void _detail::WinHttpRequest::SendRequest( - Azure::Core::Http::Request& request, - Azure::Core::Context const& context) -{ - std::wstring encodedHeaders; - int encodedHeadersLength = 0; - - auto requestHeaders = request.GetHeaders(); - if (requestHeaders.size() != 0) - { - // The encodedHeaders will be null-terminated and the length is calculated. - encodedHeadersLength = -1; - std::string requestHeaderString = GetHeadersAsString(request); - requestHeaderString.append("\0"); - - encodedHeaders = StringToWideString(requestHeaderString); - } - - int64_t streamLength = request.GetBodyStream()->Length(); - - try - { - if (!m_httpAction->WaitForAction( - [&]() { - { - // Send a request. - // NB: DO NOT CHANGE THE TYPE OF THE CONTEXT PARAMETER WITHOUT UPDATING THE - // HttpAction::StatusCallback method. - if (!WinHttpSendRequest( - m_requestHandle.get(), - requestHeaders.size() == 0 ? WINHTTP_NO_ADDITIONAL_HEADERS - : encodedHeaders.c_str(), - encodedHeadersLength, - WINHTTP_NO_REQUEST_DATA, - 0, - streamLength > 0 ? static_cast(streamLength) : 0, - reinterpret_cast( - m_httpAction - .get()))) // Context for WinHTTP status callbacks for this request. - { - // Errors include: - // ERROR_WINHTTP_CANNOT_CONNECT - // ERROR_WINHTTP_CLIENT_AUTH_CERT_NEEDED - // ERROR_WINHTTP_CONNECTION_ERROR - // ERROR_WINHTTP_INCORRECT_HANDLE_STATE - // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE - // ERROR_WINHTTP_INTERNAL_ERROR - // ERROR_WINHTTP_INVALID_URL - // ERROR_WINHTTP_LOGIN_FAILURE - // ERROR_WINHTTP_NAME_NOT_RESOLVED - // ERROR_WINHTTP_OPERATION_CANCELLED - // ERROR_WINHTTP_RESPONSE_DRAIN_OVERFLOW - // ERROR_WINHTTP_SECURE_FAILURE - // ERROR_WINHTTP_SHUTDOWN - // ERROR_WINHTTP_TIMEOUT - // ERROR_WINHTTP_UNRECOGNIZED_SCHEME - // ERROR_NOT_ENOUGH_MEMORY - // ERROR_INVALID_PARAMETER - // ERROR_WINHTTP_RESEND_REQUEST - GetErrorAndThrow("Error while sending a request."); - } - } - }, - WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE, - context)) - { - GetErrorAndThrow( - "Error while waiting for a send to complete.", m_httpAction->GetStowedError()); - } - - // Chunked transfer encoding is not supported and the content length needs to be known up - // front. - if (streamLength == -1) - { - throw Azure::Core::Http::TransportException( - "When uploading data, the body stream must have a known length."); - } - - if (streamLength > 0) - { - Upload(request, context); - } - } - catch (TransportException const&) - { - // If there was a TLS validation error, then we will have closed the request handle - // during the TLS validation callback. So if an exception was thrown, if we force closed the - // request handle, clear the handle in the requestHandle to prevent a double free. - if (m_requestHandleClosed) - { - m_requestHandle.release(); - } - throw; - } -} - -void _detail::WinHttpRequest::ReceiveResponse(Azure::Core::Context const& context) -{ - // Wait to receive the response to the HTTP request initiated by WinHttpSendRequest. - // When WinHttpReceiveResponse completes successfully, the status code and response headers have - // been received. - if (!m_httpAction->WaitForAction( - [this]() { - if (!WinHttpReceiveResponse(m_requestHandle.get(), NULL)) - { - // Errors include: - // ERROR_WINHTTP_CANNOT_CONNECT - // ERROR_WINHTTP_CHUNKED_ENCODING_HEADER_SIZE_OVERFLOW - // ERROR_WINHTTP_CLIENT_AUTH_CERT_NEEDED - // ... - // ERROR_WINHTTP_TIMEOUT - // ERROR_WINHTTP_UNRECOGNIZED_SCHEME - // ERROR_NOT_ENOUGH_MEMORY - GetErrorAndThrow("Error while receiving a response."); - } - }, - WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE, - context)) - { - GetErrorAndThrow("Error while receiving a response.", m_httpAction->GetStowedError()); - } -} - -int64_t _detail::WinHttpRequest::GetContentLength( - HttpMethod requestMethod, - HttpStatusCode responseStatusCode) -{ - DWORD dwContentLength = 0; - DWORD dwSize = sizeof(dwContentLength); - - // For Head request, set the length of body response to 0. - // Response will give us content-length as if we were not doing Head saying what would be the - // length of the body. However, server won't send any body. - // For NoContent status code, also need to set contentLength to 0. - int64_t contentLength = 0; - - // Get the content length as a number. - if (requestMethod != HttpMethod::Head && responseStatusCode != HttpStatusCode::NoContent) - { - if (!WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER, - WINHTTP_HEADER_NAME_BY_INDEX, - &dwContentLength, - &dwSize, - WINHTTP_NO_HEADER_INDEX)) - { - contentLength = -1; - } - else - { - contentLength = static_cast(dwContentLength); - } - } - - return contentLength; -} - -std::unique_ptr _detail::WinHttpRequest::SendRequestAndGetResponse( - HttpMethod requestMethod) -{ - // First, use WinHttpQueryHeaders to obtain the size of the buffer. - // The call is expected to fail since no destination buffer is provided. - DWORD sizeOfHeaders = 0; - if (WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_RAW_HEADERS, - WINHTTP_HEADER_NAME_BY_INDEX, - NULL, - &sizeOfHeaders, - WINHTTP_NO_HEADER_INDEX)) - { - // WinHttpQueryHeaders was expected to fail. - throw Azure::Core::Http::TransportException("Error while querying response headers."); - } - - { - DWORD error = GetLastError(); - if (error != ERROR_INSUFFICIENT_BUFFER) - { - GetErrorAndThrow("Error while querying response headers.", error); - } - } - - // Allocate memory for the buffer. - std::vector outputBuffer(sizeOfHeaders / sizeof(WCHAR), 0); - - // Now, use WinHttpQueryHeaders to retrieve all the headers. - // Each header is terminated by "\0". An additional "\0" terminates the list of headers. - if (!WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_RAW_HEADERS, - WINHTTP_HEADER_NAME_BY_INDEX, - outputBuffer.data(), - &sizeOfHeaders, - WINHTTP_NO_HEADER_INDEX)) - { - GetErrorAndThrow("Error while querying response headers."); - } - - auto start = outputBuffer.begin(); - auto last = start + sizeOfHeaders / sizeof(WCHAR); - auto statusLineEnd = std::find(start, last, '\0'); - start = statusLineEnd + 1; // start of headers - std::string responseHeaders = WideStringToString(std::wstring(start, last)); - - DWORD sizeOfHttp = sizeOfHeaders; - - // Get the HTTP version. - if (!WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_VERSION, - WINHTTP_HEADER_NAME_BY_INDEX, - outputBuffer.data(), - &sizeOfHttp, - WINHTTP_NO_HEADER_INDEX)) - { - GetErrorAndThrow("Error while querying response headers."); - } - - start = outputBuffer.begin(); - // Assuming ASCII here is OK since the input is expected to be an HTTP version string. - std::string httpVersion = WideStringToStringASCII(start, start + sizeOfHttp / sizeof(WCHAR)); - - uint16_t majorVersion = 0; - uint16_t minorVersion = 0; - ParseHttpVersion(httpVersion, &majorVersion, &minorVersion); - - DWORD statusCode = 0; - DWORD dwSize = sizeof(statusCode); - - // Get the status code as a number. - if (!WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, - WINHTTP_HEADER_NAME_BY_INDEX, - &statusCode, - &dwSize, - WINHTTP_NO_HEADER_INDEX)) - { - GetErrorAndThrow("Error while querying response headers."); - } - - HttpStatusCode httpStatusCode = static_cast(statusCode); - - // Get the optional reason phrase. - std::string reasonPhrase; - DWORD sizeOfReasonPhrase = sizeOfHeaders; - - // HTTP/2 does not support reason phrase, refer to - // https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.4. - if (majorVersion == 1) - { - if (WinHttpQueryHeaders( - m_requestHandle.get(), - WINHTTP_QUERY_STATUS_TEXT, - WINHTTP_HEADER_NAME_BY_INDEX, - outputBuffer.data(), - &sizeOfReasonPhrase, - WINHTTP_NO_HEADER_INDEX)) - { - // even with HTTP/1.1, we cannot assume that reason phrase is set since it is optional - // according to https://www.rfc-editor.org/rfc/rfc2616.html#section-6.1.1. - if (sizeOfReasonPhrase > 0) + // If TLS 1.3 is not available, try to set TLS 1.2 only. + tlsOption = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; + if (!WinHttpSetOption( + sessionHandle.get(), WINHTTP_OPTION_SECURE_PROTOCOLS, &tlsOption, sizeof(tlsOption))) { - start = outputBuffer.begin(); - reasonPhrase - = WideStringToString(std::wstring(start, start + sizeOfReasonPhrase / sizeof(WCHAR))); +#endif + GetErrorAndThrow("Error while enforcing TLS version for connection request."); +#if defined(WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3) + } +#endif + } + + return sessionHandle; + } + + namespace { +#if 0 + /******************** Begin Potential Throwaway Code *******/ + // These functions are captured because they may be required when further iterations on the mTLS + // functionality is needed and it would be a shame to lose them. + struct CertName : public CERT_NAME_BLOB + { + CertName(std::string const& x500dn) : CERT_NAME_BLOB{0} + { + // Determine the size needed to encode the buffer. + CertStrToName( + X509_ASN_ENCODING, + x500dn.c_str(), + CERT_X500_NAME_STR, + nullptr, + pbData, + &cbData, + nullptr); + pbData = new BYTE[cbData]; + if (!CertStrToName( + X509_ASN_ENCODING, + x500dn.c_str(), + CERT_X500_NAME_STR, + nullptr, + pbData, + &cbData, + nullptr)) + { + throw std::runtime_error("Failed to convert string to name blob"); + } + } + ~CertName() { delete pbData; } + }; + + struct CryptDataBlob : public CRYPT_DATA_BLOB + { + CryptDataBlob() : CRYPT_DATA_BLOB{0} {} + ~CryptDataBlob() { delete pbData; } + }; + + wil::unique_hcryptprov CreateCryptoProvider() + { + wil::unique_hcryptprov hCryptProv; + if (!CryptAcquireContext( + hCryptProv.addressof(), NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) + { + GetErrorAndThrow("Failed to acquire crypto context: "); + } + + return hCryptProv; + } + + wil::unique_ncrypt_key DecodePemPrivateKey( + HCRYPTPROV cryptoProvider, + std::string const& privateKey) + { + DWORD privateKeyLength = 0; + if (!CryptStringToBinary( + privateKey.c_str(), + 0, + CRYPT_STRING_BASE64HEADER, + nullptr, + &privateKeyLength, + nullptr, + nullptr)) + { + GetErrorAndThrow("Failed to get size of convert private key in binary: "); + } + std::vector privateKeyBinary(privateKeyLength); + if (!CryptStringToBinary( + privateKey.c_str(), + 0, + CRYPT_STRING_BASE64HEADER, + privateKeyBinary.data(), + &privateKeyLength, + nullptr, + nullptr)) + { + GetErrorAndThrow("Failed to convert private key to binary: "); + } + + wil::unique_ncrypt_key nCryptKey; + if (!NCryptImportKey( + cryptoProvider, + 0, + NCRYPT_PKCS8_PRIVATE_KEY_BLOB, + nullptr, + nCryptKey.addressof(), + privateKeyBinary.data(), + privateKeyLength, + NULL)) + { + GetErrorAndThrow("Failed to add private key to store: "); + } + return nCryptKey; + } + + wil::unique_cert_context DecodePemCertificate(std::string const& certificatePem) + { + DWORD certificateLength = 0; + if (!CryptStringToBinary( + certificatePem.c_str(), + 0, + CRYPT_STRING_BASE64HEADER, + nullptr, + &certificateLength, + nullptr, + nullptr)) + { + } + std::vector certificateBinary(certificateLength); + if (!CryptStringToBinary( + certificatePem.c_str(), + 0, + CRYPT_STRING_BASE64HEADER, + certificateBinary.data(), + &certificateLength, + nullptr, + nullptr)) + { + } + + wil::unique_cert_context certContext{CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + certificateBinary.data(), + static_cast(certificateBinary.size()))}; + if (!certContext) + { + GetErrorAndThrow("Failed to create certificate context"); + } + + return certContext; + } + + void ExportCertificateToPfxFile(PCCERT_CONTEXT certificate, const std::string pfxFileName) + { + wil::unique_hcertstore store{ + CertOpenStore(CERT_STORE_PROV_MEMORY, 0, 0, CERT_STORE_CREATE_NEW_FLAG, nullptr)}; + if (!store) + { + throw std::runtime_error("Failed to create certificate store"); + } + if (!CertAddCertificateContextToStore(store.get(), certificate, CERT_STORE_ADD_NEW, nullptr)) + { + throw std::runtime_error("Failed to add certificate to store"); + } + CryptDataBlob pfxBlob; + + if (!PFXExportCertStoreEx( + store.get(), + &pfxBlob, + nullptr, + nullptr, + (PKCS12_INCLUDE_EXTENDED_PROPERTIES | REPORT_NOT_ABLE_TO_EXPORT_PRIVATE_KEY + | REPORT_NO_PRIVATE_KEY | EXPORT_PRIVATE_KEYS))) + { + throw std::runtime_error("Failed to export certificate"); + } + + pfxBlob.pbData = new BYTE[pfxBlob.cbData]; + + if (!PFXExportCertStoreEx( + store.get(), + &pfxBlob, + nullptr, + nullptr, + (PKCS12_INCLUDE_EXTENDED_PROPERTIES | REPORT_NOT_ABLE_TO_EXPORT_PRIVATE_KEY + | REPORT_NO_PRIVATE_KEY | EXPORT_PRIVATE_KEYS))) + { + throw std::runtime_error("Failed to export certificate"); + } + + { + wil::unique_hfile pfxFile{ + CreateFileA(pfxFileName.c_str(), GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, 0, nullptr)}; + + if (!WriteFile(pfxFile.get(), pfxBlob.pbData, pfxBlob.cbData, nullptr, nullptr)) + { + throw std::runtime_error("Failed to write pfx file"); + } + } + } + + wil::unique_hcertstore ImportPfxCertificateStore(std::string const& fileName) + { + CryptDataBlob pfxBlob; + { + wil::unique_hfile pfxFile{ + CreateFileA(fileName.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr)}; + if (pfxFile.get() == INVALID_HANDLE_VALUE) + { + throw std::runtime_error("Failed to open pfx file"); + } + BY_HANDLE_FILE_INFORMATION fileInfo; + if (!GetFileInformationByHandle(pfxFile.get(), &fileInfo)) + { + throw std::runtime_error("Failed to get file information"); + } + pfxBlob.pbData = new BYTE[fileInfo.nFileSizeLow]; + DWORD bytesRead; + if (!ReadFile(pfxFile.get(), pfxBlob.pbData, fileInfo.nFileSizeLow, &bytesRead, nullptr)) + { + throw std::runtime_error("Failed to read pfx file"); + } + pfxBlob.cbData = bytesRead; + } + wil::unique_hcertstore store{ + PFXImportCertStore(&pfxBlob, nullptr, CRYPT_EXPORTABLE | CRYPT_MACHINE_KEYSET)}; + if (!store) + { + throw std::runtime_error("Failed to import pfx file"); + } + return store; + } + +#if 0 + + // Open a PFX store and load the certificate in the store into a TLS context. + // + // auto cryptoProvider = CreateCryptoProvider(); + // auto privateKey = DecodePemPrivateKey(cryptoProvider.get(), TlsClientPrivateKey); + // auto certificate = DecodePemCertificate(TlsClientCertificate); + + wil::unique_hcertstore store{ImportPfxCertificateStore("client_cert.pfx")}; + + wil::unique_cert_context certificate{CertEnumCertificatesInStore(store.get(), nullptr)}; + + //{ + // CERT_KEY_CONTEXT keyContext; + // keyContext.cbSize = sizeof(keyContext); + // keyContext.hCryptProv = cryptoProvider.get(); + // keyContext.hNCryptKey = privateKey.get(); + // keyContext.dwKeySpec = CERT_NCRYPT_KEY_SPEC | AT_SIGNATURE | AT_KEYEXCHANGE; + // ; + // if (!CertSetCertificateContextProperty( + // certificate.get(), + // CERT_KEY_CONTEXT_PROP_ID, + // CERT_STORE_NO_CRYPT_RELEASE_FLAG, + // &keyContext)) + // { + // throw std::runtime_error("Failed to set private key to certificate context"); + // } + //} + + // ExportCertificateToPfxFile(certificate.get(), "test.pfx"); + + wil::unique_hcryptkey key2; + if (!CryptAcquireCertificatePrivateKey( + certificate.get(), 0, nullptr, key2.addressof(), nullptr, nullptr)) + { + GTEST_LOG_(INFO) << "Failed to retrieve private key from certificate: " << GetLastError(); + } + + options.TlsClientCertificate = certificate.get(); +#endif + /******************** End Potential Throwaway Code *******/ +#endif + + WinHttpTransportOptions WinHttpTransportOptionsFromTransportOptions( + Azure::Core::Http::Policies::TransportOptions const& transportOptions) + { + WinHttpTransportOptions httpOptions; + if (transportOptions.HttpProxy.HasValue()) + { + // WinHTTP proxy strings are semicolon separated elements, each of which + // has the following format: + // ([=]["://"][":"]) + std::string proxyString; + proxyString = "http=" + transportOptions.HttpProxy.Value(); + proxyString += ";"; + proxyString += "https=" + transportOptions.HttpProxy.Value(); + httpOptions.ProxyInformation = proxyString; + } + httpOptions.ProxyUserName = transportOptions.ProxyUserName; + httpOptions.ProxyPassword = transportOptions.ProxyPassword; + // Note that WinHTTP accepts a set of root certificates, even though transportOptions only + // specifies a single one. + if (!transportOptions.ExpectedTlsRootCertificate.empty()) + { + httpOptions.ExpectedTlsRootCertificates.push_back( + transportOptions.ExpectedTlsRootCertificate); + } + if (transportOptions.EnableCertificateRevocationListCheck) + { + httpOptions.EnableCertificateRevocationListCheck = true; + } + // If you specify an expected TLS root certificate, you also need to enable ignoring + // unknown CAs. + if (!transportOptions.ExpectedTlsRootCertificate.empty()) + { + httpOptions.IgnoreUnknownCertificateAuthority = true; + } + + if (transportOptions.DisableTlsCertificateValidation) + { + httpOptions.IgnoreUnknownCertificateAuthority = true; + httpOptions.IgnoreInvalidCertificateCommonName = true; + } + + return httpOptions; + } + } // namespace + + WinHttpTransportImpl::WinHttpTransportImpl( + WinHttpTransport const* parent, + WinHttpTransportOptions const& options) + : m_parent{parent}, m_options(options), m_sessionHandle(CreateSessionHandle()) + { + if (options.TlsClientCertificate) + { + // Preserve the input client certificate for later use. + m_tlsClientCertificate.reset(CertDuplicateCertificateContext(options.TlsClientCertificate)); + if (!m_tlsClientCertificate) + { + GetErrorAndThrow("Error while duplicating client certificate context."); + } + // Erase the TLS client certificate in the m_options member because it cannot be relied upon + // from this point on. + m_options.TlsClientCertificate = nullptr; + } + } + + WinHttpTransportImpl::WinHttpTransportImpl( + WinHttpTransport const* parent, + Azure::Core::Http::Policies::TransportOptions const& transportOptions) + : WinHttpTransportImpl(parent, WinHttpTransportOptionsFromTransportOptions(transportOptions)) + { + } + + WinHttpTransportImpl::~WinHttpTransportImpl() = default; + + Azure::Core::_internal::UniqueHandle WinHttpTransportImpl::CreateConnectionHandle( + Azure::Core::Url const& url, + Azure::Core::Context const& context) + { + // If port is 0, i.e. INTERNET_DEFAULT_PORT, it uses port 80 for HTTP and port 443 for + // HTTPS. + uint16_t port = url.GetPort(); + + // Before doing any work, check to make sure that the context hasn't already been cancelled. + context.ThrowIfCancelled(); + + // Specify an HTTP server. + // This function always operates synchronously. + Azure::Core::_internal::UniqueHandle rv(WinHttpConnect( + m_sessionHandle.get(), + StringToWideString(url.GetHost()).c_str(), + port == 0 ? INTERNET_DEFAULT_PORT : port, + 0)); + + if (!rv) + { + // Errors include: + // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE + // ERROR_WINHTTP_INTERNAL_ERROR + // ERROR_WINHTTP_INVALID_URL + // ERROR_WINHTTP_OPERATION_CANCELLED + // ERROR_WINHTTP_UNRECOGNIZED_SCHEME + // ERROR_WINHTTP_SHUTDOWN + // ERROR_NOT_ENOUGH_MEMORY + GetErrorAndThrow("Error while getting a connection handle."); + } + return rv; + } + + void WinHttpRequest::EnableWebSocketsSupport() + { +#pragma warning(push) + // warning C6387: _Param_(3) could be '0'. +#pragma warning(disable : 6387) + if (!WinHttpSetOption(m_requestHandle.get(), WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, nullptr, 0)) +#pragma warning(pop) + { + GetErrorAndThrow("Error while Enabling WebSocket upgrade."); + } + } + + /** @brief Construct a new WinHttpRequest object. + * + * @param connectionHandle The connection handle to use for the request. + * @param url The URL to request. + * @param method The HTTP method to use for the request. + * @param tlsClientCertificate The client certificate to use for the request. + * @param options The transport options to use for the request. + * + * @remark Note that we *cannot* use the TlsClientCertificate field in the options passed into + * this function because the creator of the associated WinHttpTransport object may have freed the + * memory backing that object after constructing the WinHttpTransport object. Therefore, we must + * use the tlsClientCertificate saved in the WinHttpTransport object instead. + * + */ + WinHttpRequest::WinHttpRequest( + Azure::Core::_internal::UniqueHandle const& connectionHandle, + Azure::Core::Url const& url, + Azure::Core::Http::HttpMethod const& method, + PCCERT_CONTEXT tlsClientCertificate, + WinHttpTransportOptions const& options) + : m_expectedTlsRootCertificates(options.ExpectedTlsRootCertificates), + m_tlsClientCertificate(CertDuplicateCertificateContext(tlsClientCertificate)) + { + const std::string& path = url.GetRelativeUrl(); + HttpMethod requestMethod = method; + bool const requestSecureHttp( + !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( + url.GetScheme(), HttpScheme) + && !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( + url.GetScheme(), WebSocketScheme)); + + // Create an HTTP request handle. + m_requestHandle.reset(WinHttpOpenRequest( + connectionHandle.get(), + HttpMethodToWideString(requestMethod).c_str(), + path.empty() ? NULL : StringToWideString(path).c_str(), // Name of the target resource + // of the specified HTTP verb + NULL, // Use HTTP/1.1 + WINHTTP_NO_REFERER, + WINHTTP_DEFAULT_ACCEPT_TYPES, // No media types are accepted by the client + requestSecureHttp ? WINHTTP_FLAG_SECURE + : 0)); // Uses secure transaction semantics (SSL/TLS) + if (!m_requestHandle) + { + // Errors include: + // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE + // ERROR_WINHTTP_INTERNAL_ERROR + // ERROR_WINHTTP_INVALID_URL + // ERROR_WINHTTP_OPERATION_CANCELLED + // ERROR_WINHTTP_UNRECOGNIZED_SCHEME + // ERROR_NOT_ENOUGH_MEMORY + GetErrorAndThrow("Error while getting a request handle."); + } + + if (requestSecureHttp) + { + if (!m_tlsClientCertificate) + { + // If the service requests TLS client certificates, we want to let the WinHTTP APIs know + // that it's ok to initiate the request without a client certificate. + // + // Note: If/When TLS client certificate support is added to the pipeline, this line may + // need to be revisited. + if (!WinHttpSetOption( + m_requestHandle.get(), + WINHTTP_OPTION_CLIENT_CERT_CONTEXT, + WINHTTP_NO_CLIENT_CERT_CONTEXT, + 0)) + { + GetErrorAndThrow("Error while setting client cert context to ignore."); + } + } + } + + if (!options.ProxyInformation.empty()) + { + WINHTTP_PROXY_INFO proxyInfo{}; + std::wstring proxyWide{StringToWideString(options.ProxyInformation)}; + proxyInfo.dwAccessType = WINHTTP_ACCESS_TYPE_NAMED_PROXY; + proxyInfo.lpszProxy = const_cast(proxyWide.c_str()); + proxyInfo.lpszProxyBypass = WINHTTP_NO_PROXY_BYPASS; + if (!WinHttpSetOption( + m_requestHandle.get(), WINHTTP_OPTION_PROXY, &proxyInfo, sizeof(proxyInfo))) + { + GetErrorAndThrow("Error while setting Proxy information."); + } + } + if (options.ProxyUserName.HasValue() || options.ProxyPassword.HasValue()) + { + if (!WinHttpSetCredentials( + m_requestHandle.get(), + WINHTTP_AUTH_TARGET_PROXY, + WINHTTP_AUTH_SCHEME_BASIC, + StringToWideString(options.ProxyUserName.Value()).c_str(), + StringToWideString(options.ProxyPassword.Value()).c_str(), + 0)) + { + GetErrorAndThrow("Error while setting Proxy credentials."); + } + } + + if (options.IgnoreUnknownCertificateAuthority || !options.ExpectedTlsRootCertificates.empty()) + { + auto option = SECURITY_FLAG_IGNORE_UNKNOWN_CA; + if (!WinHttpSetOption( + m_requestHandle.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) + { + GetErrorAndThrow("Error while setting ignore unknown server certificate."); + } + } + + if (options.IgnoreInvalidCertificateCommonName) + { + auto option = SECURITY_FLAG_IGNORE_CERT_CN_INVALID; + if (!WinHttpSetOption( + m_requestHandle.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) + { + GetErrorAndThrow("Error while setting ignore invalid certificate common name."); + } + } + + if (options.EnableCertificateRevocationListCheck) + { + DWORD value = WINHTTP_ENABLE_SSL_REVOCATION; + if (!WinHttpSetOption( + m_requestHandle.get(), WINHTTP_OPTION_ENABLE_FEATURE, &value, sizeof(value))) + { + GetErrorAndThrow("Error while enabling CRL validation."); + } + } + + DWORD disableRedirects = WINHTTP_DISABLE_REDIRECTS; + if (!WinHttpSetOption( + m_requestHandle.get(), + WINHTTP_OPTION_DISABLE_FEATURE, + &disableRedirects, + sizeof(disableRedirects))) + { + GetErrorAndThrow("Error while disabling redirects."); + } + + // Set the callback function to be called whenever the state of the request handle changes. + m_httpAction = std::make_unique<_detail::WinHttpAction>(this); + + if (!m_httpAction->RegisterWinHttpStatusCallback(m_requestHandle)) + { + GetErrorAndThrow("Error while setting up the status callback."); + } + } + + /* + * Destructor for WinHTTP request. Closes the request handle. + */ + WinHttpRequest::~WinHttpRequest() + { + if (!m_requestHandleClosed) + { + Log::Write( + Logger::Level::Informational, + "WinHttpRequest::~WinHttpRequest. Closing handle synchronously."); + + // Close the outstanding request handle, waiting until the HANDLE_CLOSING status is + // received. + if (!m_httpAction->WaitForAction( + [this]() { + auto requestHandle = m_requestHandle.release(); + if (!WinHttpCloseHandle(requestHandle)) + { + Log::Write( + Logger::Level::Error, + "Error closing WinHTTP handle: " + GetErrorMessage(GetLastError())); + } + }, + + WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING, + Azure::Core::Context{})) + { + Log::Write(Logger::Level::Error, "Error while closing the request handle."); + } + Log::Write(Logger::Level::Informational, "WinHttpRequest::~WinHttpRequest. Handle closed."); + } + } + + std::unique_ptr WinHttpTransportImpl::CreateRequestHandle( + Azure::Core::_internal::UniqueHandle const& connectionHandle, + Azure::Core::Url const& url, + Azure::Core::Http::HttpMethod const& method) + { + auto request{std::make_unique( + connectionHandle, url, method, m_tlsClientCertificate.get(), m_options)}; + // If we are supporting WebSockets, then let WinHTTP know that it should + // prepare to upgrade the HttpRequest to a WebSocket. + if (HasWebSocketSupport()) + { + request->EnableWebSocketsSupport(); + } + return request; + } + + // For PUT/POST requests, send additional data using WinHttpWriteData. + void WinHttpRequest::Upload( + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) + { + auto streamBody = request.GetBodyStream(); + int64_t streamLength = streamBody->Length(); + + // Consider using `MaximumUploadChunkSize` here, after some perf measurements + size_t uploadChunkSize = DefaultUploadChunkSize; + if (streamLength < MaximumUploadChunkSize) + { + uploadChunkSize = static_cast(streamLength); + } + auto unique_buffer = std::make_unique(uploadChunkSize); + + while (true) + { + size_t rawRequestLen = streamBody->Read(unique_buffer.get(), uploadChunkSize, context); + if (rawRequestLen == 0) + { + break; + } + + DWORD dwBytesWritten = 0; + + if (!m_httpAction->WaitForAction( + [&]() { // Write data to the server. + if (!WinHttpWriteData( + m_requestHandle.get(), + unique_buffer.get(), + static_cast(rawRequestLen), + &dwBytesWritten)) + { + GetErrorAndThrow("Error while uploading/sending data."); + } + }, + WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE, + context)) + + { + GetErrorAndThrow( + "Error sending HTTP request asynchronously", m_httpAction->GetStowedError()); } } } - // Allocate the instance of the response on the heap with a shared ptr so this memory gets - // delegated outside the transport and will be eventually released. - auto rawResponse - = std::make_unique(majorVersion, minorVersion, httpStatusCode, reasonPhrase); - - SetHeaders(responseHeaders, rawResponse); - - return rawResponse; -} - -std::unique_ptr WinHttpTransport::Send(Request& request, Context const& context) -{ - Azure::Core::_internal::UniqueHandle connectionHandle - = CreateConnectionHandle(request.GetUrl(), context); - std::unique_ptr<_detail::WinHttpRequest> requestHandle( - CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod())); - - requestHandle->SendRequest(request, context); - requestHandle->ReceiveResponse(context); - - auto rawResponse{requestHandle->SendRequestAndGetResponse(request.GetMethod())}; - if (rawResponse && HasWebSocketSupport() - && (rawResponse->GetStatusCode() == HttpStatusCode::SwitchingProtocols)) + void WinHttpRequest::SendRequest( + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) { - OnUpgradedConnection(requestHandle); - } - else - { - int64_t contentLength - = requestHandle->GetContentLength(request.GetMethod(), rawResponse->GetStatusCode()); + std::wstring encodedHeaders; + int encodedHeadersLength = 0; - rawResponse->SetBodyStream( - std::make_unique<_detail::WinHttpStream>(requestHandle, contentLength)); - } - return rawResponse; -} + auto requestHeaders = request.GetHeaders(); + if (requestHeaders.size() != 0) + { + // The encodedHeaders will be null-terminated and the length is calculated. + encodedHeadersLength = -1; + std::string requestHeaderString = GetHeadersAsString(request); + requestHeaderString.append("\0"); -size_t _detail::WinHttpRequest::ReadData( - uint8_t* buffer, - size_t count, - Azure::Core::Context const& context) -{ - DWORD numberOfBytesRead = 0; - if (!m_httpAction->WaitForAction( - [&]() { - if (!WinHttpReadData( - this->m_requestHandle.get(), - (LPVOID)(buffer), - static_cast(count), - &numberOfBytesRead)) - { - // Errors include: - // ERROR_WINHTTP_CONNECTION_ERROR - // ERROR_WINHTTP_INCORRECT_HANDLE_STATE - // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE - // ERROR_WINHTTP_INTERNAL_ERROR - // ERROR_WINHTTP_OPERATION_CANCELLED - // ERROR_WINHTTP_RESPONSE_DRAIN_OVERFLOW - // ERROR_WINHTTP_TIMEOUT - // ERROR_NOT_ENOUGH_MEMORY + encodedHeaders = StringToWideString(requestHeaderString); + } - DWORD error = GetLastError(); - throw Azure::Core::Http::TransportException( - "Error while reading available data from the wire. Error Code: " - + std::to_string(error) + "."); - } - Log::Write( - Logger::Level::Verbose, - "Read Data read from wire. Size: " + std::to_string(numberOfBytesRead) + "."); - }, - WINHTTP_CALLBACK_STATUS_READ_COMPLETE, - context)) - { - GetErrorAndThrow("Error sending HTTP request asynchronously", m_httpAction->GetStowedError()); - } - if (numberOfBytesRead == 0) - { - numberOfBytesRead = m_httpAction->GetBytesAvailable(); + int64_t streamLength = request.GetBodyStream()->Length(); + + if (m_tlsClientCertificate) + { + Log::Stream(Logger::Level::Verbose) + << "Client certificate needed, providing before request.." << std::endl; + if (!WinHttpSetOption( + m_requestHandle.get(), + WINHTTP_OPTION_CLIENT_CERT_CONTEXT, + reinterpret_cast(const_cast(m_tlsClientCertificate.get())), + sizeof(CERT_CONTEXT))) + { + GetErrorAndThrow("Error setting client certificate."); + } + } + + try + { + if (!m_httpAction->WaitForAction( + [&]() { + { + // Send a request. + // NB: DO NOT CHANGE THE TYPE OF THE CONTEXT PARAMETER WITHOUT UPDATING THE + // HttpAction::StatusCallback method. + if (!WinHttpSendRequest( + m_requestHandle.get(), + requestHeaders.size() == 0 ? WINHTTP_NO_ADDITIONAL_HEADERS + : encodedHeaders.c_str(), + encodedHeadersLength, + WINHTTP_NO_REQUEST_DATA, + 0, + streamLength > 0 ? static_cast(streamLength) : 0, + reinterpret_cast( + m_httpAction.get()))) // Context for WinHTTP status callbacks for + // this request. + { + // Errors include: + // ERROR_WINHTTP_CANNOT_CONNECT + // ERROR_WINHTTP_CLIENT_AUTH_CERT_NEEDED + // ERROR_WINHTTP_CONNECTION_ERROR + // ERROR_WINHTTP_INCORRECT_HANDLE_STATE + // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE + // ERROR_WINHTTP_INTERNAL_ERROR + // ERROR_WINHTTP_INVALID_URL + // ERROR_WINHTTP_LOGIN_FAILURE + // ERROR_WINHTTP_NAME_NOT_RESOLVED + // ERROR_WINHTTP_OPERATION_CANCELLED + // ERROR_WINHTTP_RESPONSE_DRAIN_OVERFLOW + // ERROR_WINHTTP_SECURE_FAILURE + // ERROR_WINHTTP_SHUTDOWN + // ERROR_WINHTTP_TIMEOUT + // ERROR_WINHTTP_UNRECOGNIZED_SCHEME + // ERROR_NOT_ENOUGH_MEMORY + // ERROR_INVALID_PARAMETER + // ERROR_WINHTTP_RESEND_REQUEST + GetErrorAndThrow("Error while sending a request."); + } + } + }, + WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE, + context)) + { + GetErrorAndThrow( + "Error while waiting for a send to complete.", m_httpAction->GetStowedError()); + } + + // Chunked transfer encoding is not supported and the content length needs to be known up + // front. + if (streamLength == -1) + { + throw Azure::Core::Http::TransportException( + "When uploading data, the body stream must have a known length."); + } + + if (streamLength > 0) + { + Upload(request, context); + } + } + catch (TransportException const&) + { + // If there was a TLS validation error, then we will have closed the request handle + // during the TLS validation callback. So if an exception was thrown, if we force closed + // the request handle, clear the handle in the requestHandle to prevent a double free. + if (m_requestHandleClosed) + { + m_requestHandle.release(); + } + throw; + } } - Log::Write( - Logger::Level::Verbose, "ReadData returned size: " + std::to_string(numberOfBytesRead) + "."); - - return numberOfBytesRead; -} - -// Read the response from the sent request. -size_t _detail::WinHttpStream::OnRead(uint8_t* buffer, size_t count, Context const& context) -{ - if (count == 0 || this->m_isEOF) + void WinHttpRequest::ReceiveResponse(Azure::Core::Context const& context) { - return 0; + // Wait to receive the response to the HTTP request initiated by WinHttpSendRequest. + // When WinHttpReceiveResponse completes successfully, the status code and response headers + // have been received. + if (!m_httpAction->WaitForAction( + [this]() { + if (!WinHttpReceiveResponse(m_requestHandle.get(), NULL)) + { + // Errors include: + // ERROR_WINHTTP_CANNOT_CONNECT + // ERROR_WINHTTP_CHUNKED_ENCODING_HEADER_SIZE_OVERFLOW + // ERROR_WINHTTP_CLIENT_AUTH_CERT_NEEDED + // ... + // ERROR_WINHTTP_TIMEOUT + // ERROR_WINHTTP_UNRECOGNIZED_SCHEME + // ERROR_NOT_ENOUGH_MEMORY + GetErrorAndThrow("Error while receiving a response."); + } + }, + WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE, + context)) + { + GetErrorAndThrow("Error while receiving a response.", m_httpAction->GetStowedError()); + } } - size_t numberOfBytesRead = m_requestHandle->ReadData(buffer, count, context); - - this->m_streamTotalRead += numberOfBytesRead; - - if (numberOfBytesRead == 0 - || (this->m_contentLength != -1 && this->m_streamTotalRead == this->m_contentLength)) + int64_t WinHttpRequest::GetContentLength( + HttpMethod requestMethod, + HttpStatusCode responseStatusCode) { - this->m_isEOF = true; + DWORD dwContentLength = 0; + DWORD dwSize = sizeof(dwContentLength); + + // For Head request, set the length of body response to 0. + // Response will give us content-length as if we were not doing Head saying what would be + // the length of the body. However, server won't send any body. For NoContent status code, + // also need to set contentLength to 0. + int64_t contentLength = 0; + + // Get the content length as a number. + if (requestMethod != HttpMethod::Head && responseStatusCode != HttpStatusCode::NoContent) + { + if (!WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, + &dwContentLength, + &dwSize, + WINHTTP_NO_HEADER_INDEX)) + { + contentLength = -1; + } + else + { + contentLength = static_cast(dwContentLength); + } + } + + return contentLength; } - return numberOfBytesRead; -} + + std::unique_ptr WinHttpRequest::SendRequestAndGetResponse(HttpMethod requestMethod) + { + // First, use WinHttpQueryHeaders to obtain the size of the buffer. + // The call is expected to fail since no destination buffer is provided. + DWORD sizeOfHeaders = 0; + if (WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_RAW_HEADERS, + WINHTTP_HEADER_NAME_BY_INDEX, + NULL, + &sizeOfHeaders, + WINHTTP_NO_HEADER_INDEX)) + { + // WinHttpQueryHeaders was expected to fail. + throw Azure::Core::Http::TransportException("Error while querying response headers."); + } + + { + DWORD error = GetLastError(); + if (error != ERROR_INSUFFICIENT_BUFFER) + { + GetErrorAndThrow("Error while querying response headers.", error); + } + } + + // Allocate memory for the buffer. + std::vector outputBuffer(sizeOfHeaders / sizeof(WCHAR), 0); + + // Now, use WinHttpQueryHeaders to retrieve all the headers. + // Each header is terminated by "\0". An additional "\0" terminates the list of headers. + if (!WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_RAW_HEADERS, + WINHTTP_HEADER_NAME_BY_INDEX, + outputBuffer.data(), + &sizeOfHeaders, + WINHTTP_NO_HEADER_INDEX)) + { + GetErrorAndThrow("Error while querying response headers."); + } + + auto start = outputBuffer.begin(); + auto last = start + sizeOfHeaders / sizeof(WCHAR); + auto statusLineEnd = std::find(start, last, '\0'); + start = statusLineEnd + 1; // start of headers + std::string responseHeaders = WideStringToString(std::wstring(start, last)); + + DWORD sizeOfHttp = sizeOfHeaders; + + // Get the HTTP version. + if (!WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_VERSION, + WINHTTP_HEADER_NAME_BY_INDEX, + outputBuffer.data(), + &sizeOfHttp, + WINHTTP_NO_HEADER_INDEX)) + { + GetErrorAndThrow("Error while querying response headers."); + } + + start = outputBuffer.begin(); + // Assuming ASCII here is OK since the input is expected to be an HTTP version string. + std::string httpVersion = WideStringToStringASCII(start, start + sizeOfHttp / sizeof(WCHAR)); + + uint16_t majorVersion = 0; + uint16_t minorVersion = 0; + ParseHttpVersion(httpVersion, &majorVersion, &minorVersion); + + DWORD statusCode = 0; + DWORD dwSize = sizeof(statusCode); + + // Get the status code as a number. + if (!WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, + &statusCode, + &dwSize, + WINHTTP_NO_HEADER_INDEX)) + { + GetErrorAndThrow("Error while querying response headers."); + } + + HttpStatusCode httpStatusCode = static_cast(statusCode); + + // Get the optional reason phrase. + std::string reasonPhrase; + DWORD sizeOfReasonPhrase = sizeOfHeaders; + + // HTTP/2 does not support reason phrase, refer to + // https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.4. + if (majorVersion == 1) + { + if (WinHttpQueryHeaders( + m_requestHandle.get(), + WINHTTP_QUERY_STATUS_TEXT, + WINHTTP_HEADER_NAME_BY_INDEX, + outputBuffer.data(), + &sizeOfReasonPhrase, + WINHTTP_NO_HEADER_INDEX)) + { + // even with HTTP/1.1, we cannot assume that reason phrase is set since it is optional + // according to https://www.rfc-editor.org/rfc/rfc2616.html#section-6.1.1. + if (sizeOfReasonPhrase > 0) + { + start = outputBuffer.begin(); + reasonPhrase + = WideStringToString(std::wstring(start, start + sizeOfReasonPhrase / sizeof(WCHAR))); + } + } + } + + // Allocate the instance of the response on the heap with a shared ptr so this memory gets + // delegated outside the transport and will be eventually released. + auto rawResponse + = std::make_unique(majorVersion, minorVersion, httpStatusCode, reasonPhrase); + + SetHeaders(responseHeaders, rawResponse); + + return rawResponse; + } + + std::unique_ptr WinHttpTransportImpl::Send(Request& request, Context const& context) + { + Azure::Core::_internal::UniqueHandle connectionHandle + = CreateConnectionHandle(request.GetUrl(), context); + std::unique_ptr<_detail::WinHttpRequest> requestHandle( + CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod())); + + requestHandle->SendRequest(request, context); + requestHandle->ReceiveResponse(context); + + auto rawResponse{requestHandle->SendRequestAndGetResponse(request.GetMethod())}; + if (rawResponse && HasWebSocketSupport() + && (rawResponse->GetStatusCode() == HttpStatusCode::SwitchingProtocols)) + { + OnUpgradedConnection(requestHandle); + } + else + { + int64_t contentLength + = requestHandle->GetContentLength(request.GetMethod(), rawResponse->GetStatusCode()); + + rawResponse->SetBodyStream( + std::make_unique<_detail::WinHttpStream>(requestHandle, contentLength)); + } + return rawResponse; + } + + size_t WinHttpRequest::ReadData( + uint8_t* buffer, + size_t count, + Azure::Core::Context const& context) + { + DWORD numberOfBytesRead = 0; + if (!m_httpAction->WaitForAction( + [&]() { + if (!WinHttpReadData( + this->m_requestHandle.get(), + (LPVOID)(buffer), + static_cast(count), + &numberOfBytesRead)) + { + // Errors include: + // ERROR_WINHTTP_CONNECTION_ERROR + // ERROR_WINHTTP_INCORRECT_HANDLE_STATE + // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE + // ERROR_WINHTTP_INTERNAL_ERROR + // ERROR_WINHTTP_OPERATION_CANCELLED + // ERROR_WINHTTP_RESPONSE_DRAIN_OVERFLOW + // ERROR_WINHTTP_TIMEOUT + // ERROR_NOT_ENOUGH_MEMORY + + DWORD error = GetLastError(); + throw Azure::Core::Http::TransportException( + "Error while reading available data from the wire. Error Code: " + + std::to_string(error) + "."); + } + Log::Write( + Logger::Level::Verbose, + "Read Data read from wire. Size: " + std::to_string(numberOfBytesRead) + "."); + }, + WINHTTP_CALLBACK_STATUS_READ_COMPLETE, + context)) + { + GetErrorAndThrow("Error sending HTTP request asynchronously", m_httpAction->GetStowedError()); + } + if (numberOfBytesRead == 0) + { + numberOfBytesRead = m_httpAction->GetBytesAvailable(); + } + + Log::Write( + Logger::Level::Verbose, + "ReadData returned size: " + std::to_string(numberOfBytesRead) + "."); + + return numberOfBytesRead; + } + + // Read the response from the sent request. + size_t WinHttpStream::OnRead(uint8_t* buffer, size_t count, Context const& context) + { + if (count == 0 || this->m_isEOF) + { + return 0; + } + + size_t numberOfBytesRead = m_requestHandle->ReadData(buffer, count, context); + + this->m_streamTotalRead += numberOfBytesRead; + + if (numberOfBytesRead == 0 + || (this->m_contentLength != -1 && this->m_streamTotalRead == this->m_contentLength)) + { + this->m_isEOF = true; + } + return numberOfBytesRead; + } +}}}} // namespace Azure::Core::Http::_detail diff --git a/sdk/core/azure-core/test/ut/transport_adapter_base_test.cpp b/sdk/core/azure-core/test/ut/transport_adapter_base_test.cpp index 4daf33ac8..65d05e8e2 100644 --- a/sdk/core/azure-core/test/ut/transport_adapter_base_test.cpp +++ b/sdk/core/azure-core/test/ut/transport_adapter_base_test.cpp @@ -49,7 +49,17 @@ namespace Azure { namespace Core { namespace Test { auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, host); auto response = m_pipeline->Send(request, Context{}); checkResponseCode(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::NoContent); - auto expectedResponseBodySize = std::stoull(response->GetHeaders().at("content-length")); + std::uint64_t expectedResponseBodySize; + if (response->GetStatusCode() == Azure::Core::Http::HttpStatusCode::NoContent) + { + // http://mt3.google.com/generate_204 returns 204 with no body and thus no content-length + // header + expectedResponseBodySize = 0; + } + else + { + expectedResponseBodySize = std::stoull(response->GetHeaders().at("content-length")); + } CheckBodyFromBuffer(*response, expectedResponseBodySize); }