diff --git a/sdk/storage/azure-storage-blobs/src/blob_client.cpp b/sdk/storage/azure-storage-blobs/src/blob_client.cpp index f172d0d8e..0272b5031 100644 --- a/sdk/storage/azure-storage-blobs/src/blob_client.cpp +++ b/sdk/storage/azure-storage-blobs/src/blob_client.cpp @@ -204,10 +204,15 @@ namespace Azure { namespace Storage { namespace Blobs { { // In case network failure during reading the body const Azure::ETag eTag = downloadResponse.Value.Details.ETag; - - auto retryFunction - = [this, options, eTag](int64_t retryOffset, const Azure::Core::Context& context) - -> std::unique_ptr { + const std::string client_request_id + = downloadResponse.RawResponse->GetHeaders().find(_internal::HttpHeaderClientRequestId) + == downloadResponse.RawResponse->GetHeaders().end() + ? std::string() + : downloadResponse.RawResponse->GetHeaders().at(_internal::HttpHeaderClientRequestId); + auto retryFunction = + [this, options, eTag, client_request_id]( + int64_t retryOffset, + const Azure::Core::Context& context) -> std::unique_ptr { DownloadBlobOptions newOptions = options; newOptions.Range = Core::Http::HttpRange(); newOptions.Range.Value().Offset @@ -217,7 +222,11 @@ namespace Azure { namespace Storage { namespace Blobs { newOptions.Range.Value().Length = options.Range.Value().Length.Value() - retryOffset; } newOptions.AccessConditions.IfMatch = eTag; - return std::move(Download(newOptions, context).Value.BodyStream); + return std::move( + Download( + newOptions, + context.WithValue(_internal::ReliableStreamClientRequestIdKey, client_request_id)) + .Value.BodyStream); }; _internal::ReliableStreamOptions reliableStreamOptions; diff --git a/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/constants.hpp b/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/constants.hpp index 1b7db941e..df0486cf7 100644 --- a/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/constants.hpp +++ b/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/constants.hpp @@ -10,9 +10,11 @@ namespace Azure { namespace Storage { namespace _internal { constexpr static const char* QueueServicePackageName = "storage-queues"; constexpr static const char* HttpQuerySnapshot = "snapshot"; constexpr static const char* HttpQueryVersionId = "versionid"; + constexpr static const char* HttpQueryTimeout = "timeout"; constexpr static const char* StorageScope = "https://storage.azure.com/.default"; constexpr static const char* StorageDefaultAudience = "https://storage.azure.com"; constexpr static const char* HttpHeaderDate = "date"; + constexpr static const char* HttpHeaderXMsDate = "x-ms-date"; constexpr static const char* HttpHeaderXMsVersion = "x-ms-version"; constexpr static const char* HttpHeaderRequestId = "x-ms-request-id"; constexpr static const char* HttpHeaderClientRequestId = "x-ms-client-request-id"; diff --git a/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/reliable_stream.hpp b/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/reliable_stream.hpp index 22c23768a..373f8e600 100644 --- a/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/reliable_stream.hpp +++ b/sdk/storage/azure-storage-common/inc/azure/storage/common/internal/reliable_stream.hpp @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#pragma once + +#include "azure/storage/common/dll_import_export.hpp" + #include #include @@ -10,6 +14,9 @@ namespace Azure { namespace Storage { namespace _internal { + AZ_STORAGE_COMMON_DLLEXPORT extern const Azure::Core::Context::Key + ReliableStreamClientRequestIdKey; + // Options used by reliable stream struct ReliableStreamOptions final { diff --git a/sdk/storage/azure-storage-common/src/reliable_stream.cpp b/sdk/storage/azure-storage-common/src/reliable_stream.cpp index 2a424c9a7..c3f9f8a30 100644 --- a/sdk/storage/azure-storage-common/src/reliable_stream.cpp +++ b/sdk/storage/azure-storage-common/src/reliable_stream.cpp @@ -20,6 +20,7 @@ namespace Azure { namespace Storage { } // namespace _detail namespace _internal { + Azure::Core::Context::Key const ReliableStreamClientRequestIdKey; size_t ReliableStream::OnRead(uint8_t* buffer, size_t count, Context const& context) { diff --git a/sdk/storage/azure-storage-common/src/storage_per_retry_policy.cpp b/sdk/storage/azure-storage-common/src/storage_per_retry_policy.cpp index ac3b74216..e9433752c 100644 --- a/sdk/storage/azure-storage-common/src/storage_per_retry_policy.cpp +++ b/sdk/storage/azure-storage-common/src/storage_per_retry_policy.cpp @@ -3,6 +3,9 @@ #include "azure/storage/common/internal/storage_per_retry_policy.hpp" +#include "azure/storage/common/internal/constants.hpp" +#include "azure/storage/common/internal/reliable_stream.hpp" + #include #include @@ -16,9 +19,6 @@ namespace Azure { namespace Storage { namespace _internal { Core::Http::Policies::NextHttpPolicy nextPolicy, Core::Context const& context) const { - const char* HttpHeaderDate = "Date"; - const char* HttpHeaderXMsDate = "x-ms-date"; - const auto& headers = request.GetHeaders(); if (headers.find(HttpHeaderDate) == headers.end()) { @@ -29,11 +29,10 @@ namespace Azure { namespace Storage { namespace _internal { .ToString(Azure::DateTime::DateFormat::Rfc1123)); } - const char* HttpHeaderTimeout = "timeout"; auto cancelTimepoint = context.GetDeadline(); if (cancelTimepoint == Azure::DateTime::max()) { - request.GetUrl().RemoveQueryParameter(HttpHeaderTimeout); + request.GetUrl().RemoveQueryParameter(HttpQueryTimeout); } else { @@ -43,8 +42,18 @@ namespace Azure { namespace Storage { namespace _internal { .count() : -1; request.GetUrl().AppendQueryParameter( - HttpHeaderTimeout, std::to_string(std::max(numSeconds, int64_t(1)))); + HttpQueryTimeout, std::to_string(std::max(numSeconds, int64_t(1)))); } + + std::string client_request_id; + if (context.TryGetValue(ReliableStreamClientRequestIdKey, client_request_id)) + { + if (!client_request_id.empty()) + { + request.SetHeader(HttpHeaderClientRequestId, client_request_id); + } + } + return nextPolicy.Send(request, context); } diff --git a/sdk/storage/azure-storage-files-shares/src/share_file_client.cpp b/sdk/storage/azure-storage-files-shares/src/share_file_client.cpp index 12ade37bc..3da9a69f4 100644 --- a/sdk/storage/azure-storage-files-shares/src/share_file_client.cpp +++ b/sdk/storage/azure-storage-files-shares/src/share_file_client.cpp @@ -305,10 +305,15 @@ namespace Azure { namespace Storage { namespace Files { namespace Shares { { // In case network failure during reading the body auto eTag = downloadResponse.Value.Details.ETag; - - auto retryFunction - = [this, options, eTag](int64_t retryOffset, const Azure::Core::Context& context) - -> std::unique_ptr { + const std::string client_request_id + = downloadResponse.RawResponse->GetHeaders().find(_internal::HttpHeaderClientRequestId) + == downloadResponse.RawResponse->GetHeaders().end() + ? std::string() + : downloadResponse.RawResponse->GetHeaders().at(_internal::HttpHeaderClientRequestId); + auto retryFunction = + [this, options, eTag, client_request_id]( + int64_t retryOffset, + const Azure::Core::Context& context) -> std::unique_ptr { DownloadFileOptions newOptions = options; newOptions.Range = Core::Http::HttpRange(); newOptions.Range.Value().Offset @@ -318,7 +323,9 @@ namespace Azure { namespace Storage { namespace Files { namespace Shares { newOptions.Range.Value().Length = options.Range.Value().Length.Value() - retryOffset; } - auto newResponse = Download(newOptions, context); + auto newResponse = Download( + newOptions, + context.WithValue(_internal::ReliableStreamClientRequestIdKey, client_request_id)); if (eTag != newResponse.Value.Details.ETag) { throw Azure::Core::RequestFailedException("File was modified in the middle of download.");