feat: Add WebSocket transport implementation
- Extract core WebSocket files from original Larry Osterman implementation
- Add WebSocket headers and implementation to build system
- Update HttpPipeline API usage to current version
- Include CURL and WinHTTP WebSocket adapters
- Add WebSocket test infrastructure
Successfully compiles with current main (azure-core target)
Files added:
- sdk/core/azure-core/inc/azure/core/http/websockets/*
- sdk/core/azure-core/src/http/websockets/*
- sdk/core/azure-core/src/http/curl/curl_websockets.cpp
- sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp
- sdk/core/azure-core/test/ut/websocket_test.cpp
Cherry-picked from: 5cff286c0 (Initial WebSockets implementation)
This commit is contained in:
parent
5eed2ccafd
commit
a090ae75da
@ -43,18 +43,22 @@ if(BUILD_TRANSPORT_CURL)
|
||||
src/http/curl/curl_connection_pool_private.hpp
|
||||
src/http/curl/curl_connection_private.hpp
|
||||
src/http/curl/curl_session_private.hpp
|
||||
src/http/curl/curl_websockets.cpp
|
||||
)
|
||||
SET(CURL_TRANSPORT_ADAPTER_INC
|
||||
inc/azure/core/http/curl_transport.hpp
|
||||
inc/azure/core/http/websockets/curl_websockets_transport.hpp
|
||||
)
|
||||
endif()
|
||||
if(BUILD_TRANSPORT_WINHTTP)
|
||||
SET(WIN_TRANSPORT_ADAPTER_SRC
|
||||
src/http/winhttp/win_http_transport.cpp
|
||||
src/http/winhttp/win_http_request.hpp
|
||||
src/http/winhttp/win_http_websockets.cpp
|
||||
)
|
||||
SET(WIN_TRANSPORT_ADAPTER_INC
|
||||
inc/azure/core/http/win_http_transport.hpp
|
||||
inc/azure/core/http/websockets/win_http_websockets_transport.hpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -80,6 +84,8 @@ set(
|
||||
inc/azure/core/http/policies/policy.hpp
|
||||
inc/azure/core/http/raw_response.hpp
|
||||
inc/azure/core/http/transport.hpp
|
||||
inc/azure/core/http/websockets/websockets.hpp
|
||||
inc/azure/core/http/websockets/websockets_transport.hpp
|
||||
inc/azure/core/internal/client_options.hpp
|
||||
inc/azure/core/internal/contract.hpp
|
||||
inc/azure/core/internal/credentials/authorization_challenge_parser.hpp
|
||||
@ -142,6 +148,8 @@ set(
|
||||
src/http/transport_policy.cpp
|
||||
src/http/url.cpp
|
||||
src/http/user_agent.cpp
|
||||
src/http/websockets/websockets.cpp
|
||||
src/http/websockets/websockets_impl.cpp
|
||||
src/io/body_stream.cpp
|
||||
src/io/random_access_file_body_stream.cpp
|
||||
src/logger.cpp
|
||||
|
||||
@ -0,0 +1,155 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via CURL.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "azure/core/context.hpp"
|
||||
#include "azure/core/http/curl_transport.hpp"
|
||||
#include "azure/core/http/http.hpp"
|
||||
#include "azure/core/http/transport.hpp"
|
||||
#include "azure/core/http/websockets/websockets_transport.hpp"
|
||||
#include <memory>
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
|
||||
struct CurlWebSocketTransportOptions : public Azure::Core::Http::CurlTransportOptions
|
||||
{
|
||||
};
|
||||
/**
|
||||
* @brief Concrete implementation of a WebSocket Transport that uses libcurl.
|
||||
*/
|
||||
class CurlWebSocketTransport : public CurlTransport, public WebSocketTransport {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new CurlWebSocketTransport object.
|
||||
*
|
||||
* @param options Optional parameter to override the default options.
|
||||
*/
|
||||
CurlWebSocketTransport(
|
||||
CurlWebSocketTransportOptions const& options = CurlWebSocketTransportOptions())
|
||||
: CurlTransport(options)
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse
|
||||
*
|
||||
* @param request an HTTP Request to be send.
|
||||
* @param context A context to control the request lifetime.
|
||||
*
|
||||
* @return unique ptr to an HTTP RawResponse.
|
||||
*/
|
||||
virtual std::unique_ptr<RawResponse> Send(Request& request, Context const& context) override;
|
||||
|
||||
/**
|
||||
* @brief Indicates if the transport natively supports websockets or not.
|
||||
*
|
||||
* @details For the CURL websocket transport, the transport does NOT support native websockets -
|
||||
* it is the responsibility of the client of the WebSocketTransport to format WebSocket protocol
|
||||
* elements.
|
||||
*/
|
||||
virtual bool HasBuiltInWebSocketSupport() override { return false; }
|
||||
|
||||
/**
|
||||
* @brief Closes the WebSocket handle.
|
||||
*
|
||||
*/
|
||||
virtual void Close() override;
|
||||
|
||||
// Native WebSocket support methods.
|
||||
/**
|
||||
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
* The first param is the close reason, the second is descriptive text.
|
||||
*/
|
||||
virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&)
|
||||
override
|
||||
{
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Retrieve the status of the close socket operation.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
*/
|
||||
NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
|
||||
const Azure::Core::Context&) override
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Send a frame of data to the remote node.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
*/
|
||||
virtual void NativeSendFrame(
|
||||
NativeWebSocketFrameType,
|
||||
std::vector<uint8_t> const&,
|
||||
Azure::Core::Context const&) override
|
||||
{
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Receive a frame of data from the remote node.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
*/
|
||||
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
|
||||
Azure::Core::Context const&) override
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
// Non-Native WebSocket support.
|
||||
/**
|
||||
* @brief This function is used when working with streams to pull more data from the wire.
|
||||
* Function will try to keep pulling data from socket until the buffer is all written or until
|
||||
* there is no more data to get from the socket.
|
||||
*
|
||||
* @param buffer Buffer to fill with data.
|
||||
* @param bufferSize Size of buffer.
|
||||
* @param context Context to control the request lifetime.
|
||||
*
|
||||
* @returns Buffer data received.
|
||||
*
|
||||
*/
|
||||
virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context)
|
||||
override;
|
||||
|
||||
/**
|
||||
* @brief This method will use libcurl socket to write all the bytes from buffer.
|
||||
*
|
||||
* @param buffer Buffer to send.
|
||||
* @param bufferSize Number of bytes to write.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context)
|
||||
override;
|
||||
|
||||
/**
|
||||
* @brief returns true if this transport supports WebSockets, false otherwise.
|
||||
*/
|
||||
bool HasWebSocketSupport() const override { return true; }
|
||||
|
||||
private:
|
||||
// std::unique_ptr cannot be constructed on an incomplete type (CurlNetworkConnection), but
|
||||
// std::shared_ptr can be.
|
||||
std::shared_ptr<Azure::Core::Http::CurlNetworkConnection> m_upgradedConnection;
|
||||
void OnUpgradedConnection(
|
||||
std::unique_ptr<Azure::Core::Http::CurlNetworkConnection>&& upgradedConnection) override;
|
||||
};
|
||||
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
@ -0,0 +1,404 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief Azure Core APIs implementing the WebSocket protocol [RFC 6455]
|
||||
* (https://www.rfc-editor.org/rfc/rfc6455.html).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "azure/core/context.hpp"
|
||||
#include "azure/core/http/http.hpp"
|
||||
#include "azure/core/http/transport.hpp"
|
||||
#include "azure/core/internal/client_options.hpp"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
namespace _detail {
|
||||
class WebSocketImplementation;
|
||||
}
|
||||
namespace _internal {
|
||||
|
||||
enum class WebSocketFrameType : int
|
||||
{
|
||||
Unknown,
|
||||
TextFrameReceived,
|
||||
BinaryFrameReceived,
|
||||
PeerClosedReceived,
|
||||
};
|
||||
|
||||
enum class WebSocketErrorCode : uint16_t
|
||||
{
|
||||
OK = 1000,
|
||||
EndpointDisappearing = 1001,
|
||||
ProtocolError = 1002,
|
||||
UnknownDataType = 1003,
|
||||
Reserved1 = 1004,
|
||||
NoStatusCodePresent = 1005,
|
||||
ConnectionClosedWithoutCloseFrame = 1006,
|
||||
InvalidMessageData = 1007,
|
||||
PolicyViolation = 1008,
|
||||
MessageTooLarge = 1009,
|
||||
ExtensionNotFound = 1010,
|
||||
UnexpectedError = 1011,
|
||||
TlsHandshakeFailure = 1015,
|
||||
};
|
||||
|
||||
class WebSocketTextFrame;
|
||||
class WebSocketBinaryFrame;
|
||||
class WebSocketPeerCloseFrame;
|
||||
|
||||
namespace _detail {
|
||||
class WebSocketImplementation;
|
||||
}
|
||||
/** @brief Statistics about data sent and received by the WebSocket.
|
||||
*
|
||||
* @remarks This class is primarily intended for test collateral and debugging to allow
|
||||
* a caller to determine information about the status of a WebSocket.
|
||||
*
|
||||
* Note: Some of these statistics are not available if the underlying transport supports native
|
||||
* websockets.
|
||||
*/
|
||||
struct WebSocketStatistics
|
||||
{
|
||||
/** @brief The number of WebSocket frames sent on this WebSocket. */
|
||||
uint32_t FramesSent;
|
||||
|
||||
/** @brief The number of bytes of data sent to the peer on this WebSocket. */
|
||||
uint32_t BytesSent;
|
||||
|
||||
/** @brief The number of WebSocket frames received from the peer. */
|
||||
uint32_t FramesReceived;
|
||||
|
||||
/** @brief The number of bytes received from the peer. */
|
||||
uint32_t BytesReceived;
|
||||
|
||||
/** @brief The number of "Ping" frames received from the peer. */
|
||||
uint32_t PingFramesReceived;
|
||||
|
||||
/** @brief The number of "Ping" frames sent to the peer. */
|
||||
uint32_t PingFramesSent;
|
||||
|
||||
/** @brief The number of "Pong" frames received from the peer. */
|
||||
uint32_t PongFramesReceived;
|
||||
|
||||
/** @brief The number of "Pong" frames sent to the peer. */
|
||||
uint32_t PongFramesSent;
|
||||
|
||||
/** @brief The number of "Text" frames received from the peer. */
|
||||
uint32_t TextFramesReceived;
|
||||
|
||||
/** @brief The number of "Text" frames sent to the peer. */
|
||||
uint32_t TextFramesSent;
|
||||
|
||||
/** @brief The number of "Binary" frames received from the peer. */
|
||||
uint32_t BinaryFramesReceived;
|
||||
|
||||
/** @brief The number of "Binary" frames sent to the peer. */
|
||||
uint32_t BinaryFramesSent;
|
||||
|
||||
/** @brief The number of "Continuation" frames sent to the peer. */
|
||||
uint32_t ContinuationFramesSent;
|
||||
|
||||
/** @brief The number of "Continuation" frames received from the peer. */
|
||||
uint32_t ContinuationFramesReceived;
|
||||
|
||||
/** @brief The number of "Close" frames received from the peer. */
|
||||
uint32_t CloseFramesReceived;
|
||||
|
||||
/** @brief The number of frames received which were not processed. */
|
||||
uint32_t FramesDropped;
|
||||
|
||||
/** @brief The number of frames received which were not returned because they were received
|
||||
* after the Close() method was called. */
|
||||
|
||||
uint32_t FramesDroppedByClose;
|
||||
/** @brief The number of frames dropped because they were over the maximum payload size. */
|
||||
|
||||
uint32_t FramesDroppedByPayloadSizeLimit;
|
||||
/** @brief The number of frames dropped because they were out of compliance with the protocol.
|
||||
*/
|
||||
uint32_t FramesDroppedByProtocolError;
|
||||
|
||||
/** @brief The number of reads performed on the transport.*/
|
||||
uint32_t TransportReads;
|
||||
|
||||
/** @brief The number of bytes read from the transport. */
|
||||
uint32_t TransportReadBytes;
|
||||
};
|
||||
|
||||
/** @brief A frame of data received from a WebSocket.
|
||||
*/
|
||||
class WebSocketFrame {
|
||||
public:
|
||||
/** @brief The type of frame received: Text, Binary or Close. */
|
||||
WebSocketFrameType FrameType{};
|
||||
|
||||
/** @brief True if the frame received is a "final" frame */
|
||||
bool IsFinalFrame{false};
|
||||
|
||||
/** @brief Returns the contents of the frame as a Text frame.
|
||||
* @returns A WebSocketTextFrame containing the contents of the frame.
|
||||
*/
|
||||
std::shared_ptr<WebSocketTextFrame> AsTextFrame();
|
||||
|
||||
/** @brief Returns the contents of the frame as a Binary frame.
|
||||
* @returns A WebSocketBinaryFrame containing the contents of the frame.
|
||||
*/
|
||||
|
||||
std::shared_ptr<WebSocketBinaryFrame> AsBinaryFrame();
|
||||
/** @brief Returns the contents of the frame as a Peer Close frame.
|
||||
* @returns A WebSocketPeerCloseFrame containing the contents of the frame.
|
||||
*/
|
||||
std::shared_ptr<WebSocketPeerCloseFrame> AsPeerCloseFrame();
|
||||
|
||||
/** @brief Construct a new instance of a WebSocketFrame.*/
|
||||
WebSocketFrame() = default;
|
||||
|
||||
/** @brief Construct a new instance of a WebSocketFrame with a specific frame type.
|
||||
* @param frameType The type of frame received.
|
||||
*/
|
||||
WebSocketFrame(WebSocketFrameType frameType) : FrameType{frameType} {}
|
||||
|
||||
/** @brief Construct a new instance of a WebSocketFrame with a specific frame type and final
|
||||
* flag.
|
||||
* @param frameType The type of frame received.
|
||||
* @param isFinalFrame true if the frame is the final frame.
|
||||
*/
|
||||
WebSocketFrame(WebSocketFrameType frameType, bool isFinalFrame)
|
||||
: FrameType{frameType}, IsFinalFrame{isFinalFrame}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/** @brief Contains the contents of a WebSocket Text frame.*/
|
||||
class WebSocketTextFrame : public WebSocketFrame,
|
||||
public std::enable_shared_from_this<WebSocketTextFrame> {
|
||||
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
|
||||
|
||||
private:
|
||||
public:
|
||||
/** @brief Constructs a new WebSocketTextFrame */
|
||||
WebSocketTextFrame() : WebSocketFrame(WebSocketFrameType::TextFrameReceived){};
|
||||
|
||||
/** @brief Text of the frame received from the remote peer. */
|
||||
std::string Text;
|
||||
|
||||
private:
|
||||
/** @brief Constructs a new WebSocketTextFrame
|
||||
* @param isFinalFrame True if this is the final frame in a multi-frame message.
|
||||
* @param body UTF-8 encoded text of the frame data.
|
||||
* @param size Length in bytes of the frame body.
|
||||
*/
|
||||
WebSocketTextFrame(bool isFinalFrame, uint8_t const* body, size_t size)
|
||||
: WebSocketFrame{WebSocketFrameType::TextFrameReceived, isFinalFrame},
|
||||
Text(body, body + size)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/** @brief Contains the contents of a WebSocket Binary frame.*/
|
||||
class WebSocketBinaryFrame : public WebSocketFrame,
|
||||
public std::enable_shared_from_this<WebSocketBinaryFrame> {
|
||||
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
|
||||
|
||||
private:
|
||||
public:
|
||||
/** @brief Constructs a new WebSocketBinaryFrame */
|
||||
WebSocketBinaryFrame() : WebSocketFrame(WebSocketFrameType::BinaryFrameReceived){};
|
||||
|
||||
/** @brief Binary frame data received from the remote peer. */
|
||||
std::vector<uint8_t> Data;
|
||||
|
||||
/** @brief Constructs a new WebSocketBinaryFrame
|
||||
* @param isFinal True if this is the final frame in a multi-frame message.
|
||||
* @param body binary of the frame data.
|
||||
* @param size Length in bytes of the frame body.
|
||||
*/
|
||||
private:
|
||||
WebSocketBinaryFrame(bool isFinal, uint8_t const* body, size_t size)
|
||||
: WebSocketFrame{WebSocketFrameType::BinaryFrameReceived, isFinal},
|
||||
Data(body, body + size)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/** @brief Contains the contents of a WebSocket Close frame.*/
|
||||
class WebSocketPeerCloseFrame : public WebSocketFrame,
|
||||
public std::enable_shared_from_this<WebSocketPeerCloseFrame> {
|
||||
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
|
||||
|
||||
public:
|
||||
/** @brief Constructs a new WebSocketPeerCloseFrame */
|
||||
WebSocketPeerCloseFrame() : WebSocketFrame(WebSocketFrameType::PeerClosedReceived){};
|
||||
|
||||
/** @brief Status code sent from the remote peer. Typically a member of the WebSocketErrorCode
|
||||
* enumeration */
|
||||
uint16_t RemoteStatusCode{};
|
||||
|
||||
/** @brief Optional text sent from the remote peer. */
|
||||
std::string RemoteCloseReason;
|
||||
|
||||
private:
|
||||
/** @brief Constructs a new WebSocketBinaryFrame
|
||||
* @param remoteStatusCode Status code sent by the remote peer.
|
||||
* @param remoteCloseReason Optional reason sent by the remote peer.
|
||||
*/
|
||||
WebSocketPeerCloseFrame(uint16_t remoteStatusCode, std::string const& remoteCloseReason)
|
||||
: WebSocketFrame{WebSocketFrameType::PeerClosedReceived},
|
||||
RemoteStatusCode(remoteStatusCode), RemoteCloseReason(remoteCloseReason)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct WebSocketOptions : Azure::Core::_internal::ClientOptions
|
||||
{
|
||||
/**
|
||||
* @brief The set of protocols which are supported by this client
|
||||
*/
|
||||
std::vector<std::string> Protocols = {};
|
||||
|
||||
/**
|
||||
* @brief The protocol name of the service client. Used for the User-Agent header
|
||||
* in the initial WebSocket handshake.
|
||||
*/
|
||||
std::string ServiceName;
|
||||
|
||||
/**
|
||||
* @brief The version of the service client. Used for the User-Agent header in the
|
||||
* initial WebSocket handshake
|
||||
*/
|
||||
std::string ServiceVersion;
|
||||
|
||||
/**
|
||||
* @brief The period of time between ping operations, default is 60 seconds.
|
||||
*/
|
||||
std::chrono::duration<int64_t> PingInterval{std::chrono::seconds{60}};
|
||||
|
||||
/**
|
||||
* @brief Construct an instance of a WebSocketOptions type.
|
||||
*
|
||||
* @param protocols Supported protocols for this websocket client.
|
||||
*/
|
||||
explicit WebSocketOptions(std::vector<std::string> protocols)
|
||||
: Azure::Core::_internal::ClientOptions{}, Protocols(protocols)
|
||||
{
|
||||
}
|
||||
WebSocketOptions() = default;
|
||||
};
|
||||
|
||||
class WebSocket {
|
||||
public:
|
||||
/** @brief Constructs a new instance of a WebSocket with the specified WebSocket options.
|
||||
*
|
||||
* @param remoteUrl The URL of the remote WebSocket server.
|
||||
* @param options The options to use for the WebSocket.
|
||||
*/
|
||||
explicit WebSocket(
|
||||
Azure::Core::Url const& remoteUrl,
|
||||
WebSocketOptions const& options = WebSocketOptions{});
|
||||
|
||||
/** @brief Destroys an instance of a WebSocket.
|
||||
*/
|
||||
~WebSocket();
|
||||
|
||||
/** @brief Opens a WebSocket connection to a remote server.
|
||||
*
|
||||
* @param context Context for the operation, used for cancellation and timeout.
|
||||
*/
|
||||
void Open(Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief Closes a WebSocket connection to the remote server gracefully.
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
void Close(Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief Closes a WebSocket connection to the remote server with additional context.
|
||||
*
|
||||
* @param closeStatus 16 bit WebSocket error code.
|
||||
* @param closeReason String describing the reason for closing the socket.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
void Close(
|
||||
uint16_t closeStatus,
|
||||
std::string const& closeReason = {},
|
||||
Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief Sends a String frame to the remote server.
|
||||
*
|
||||
* @param textFrame UTF-8 encoded text to send.
|
||||
* @param isFinalFrame if True, this is the final frame in a multi-frame message.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
void SendFrame(
|
||||
std::string const& textFrame,
|
||||
bool isFinalFrame = false,
|
||||
Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief Sends a Binary frame to the remote server.
|
||||
*
|
||||
* @param binaryFrame Binary data to send.
|
||||
* @param isFinalFrame if True, this is the final frame in a multi-frame message.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
void SendFrame(
|
||||
std::vector<uint8_t> const& binaryFrame,
|
||||
bool isFinalFrame = false,
|
||||
Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief Receive a frame from the remote server.
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
* @returns The received WebSocket frame.
|
||||
*
|
||||
*/
|
||||
std::shared_ptr<WebSocketFrame> ReceiveFrame(
|
||||
Azure::Core::Context const& context = Azure::Core::Context{});
|
||||
|
||||
/** @brief AddHeader - Adds a header to the initial handshake.
|
||||
*
|
||||
* @note This API is ignored after the WebSocket is opened.
|
||||
*
|
||||
* @param headerName Name of header to add to the initial handshake request.
|
||||
* @param headerValue Value of header to add.
|
||||
*/
|
||||
void AddHeader(std::string const& headerName, std::string const& headerValue);
|
||||
|
||||
/** @brief Determine if the WebSocket is open.
|
||||
*
|
||||
* @returns true if the WebSocket is open, false otherwise.
|
||||
*/
|
||||
bool IsOpen() const;
|
||||
|
||||
/** @brief Returns "true" if the configured websocket transport
|
||||
* supports websockets in the transport, or if the websocket implementation
|
||||
* is providing websocket protocol support.
|
||||
*
|
||||
* @returns true if the HTTP transport used for WebSocket support directly supports the
|
||||
* WebSocket API.
|
||||
*/
|
||||
bool HasBuiltInWebSocketSupport() const;
|
||||
|
||||
/** @brief Returns the protocol chosen by the remote server during the initial handshake.
|
||||
*
|
||||
* @returns The protocol negotiated between client and server.
|
||||
*/
|
||||
std::string const& GetNegotiatedProtocol() const;
|
||||
|
||||
/** @brief Returns statistics about the WebSocket.
|
||||
*
|
||||
* @returns The statistics about the WebSocket.
|
||||
*/
|
||||
WebSocketStatistics GetStatistics() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<Azure::Core::Http::WebSockets::_detail::WebSocketImplementation>
|
||||
m_socketImplementation;
|
||||
};
|
||||
} // namespace _internal
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
@ -0,0 +1,204 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief Utilities to be used by HTTP WebSocket transport implementations.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "azure/core/context.hpp"
|
||||
#include "azure/core/http/http.hpp"
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
|
||||
/**
|
||||
* @brief Base class for all WebSocket transport implementations.
|
||||
*/
|
||||
class WebSocketTransport {
|
||||
public:
|
||||
/**
|
||||
* @brief Web Socket Frame type, one of Text or Binary.
|
||||
*/
|
||||
enum class NativeWebSocketFrameType
|
||||
{
|
||||
/**
|
||||
* @brief Indicates that the frame is a partial UTF-8 encoded text frame - it is NOT the
|
||||
* complete frame to be sent to the remote node.
|
||||
*/
|
||||
TextFragment,
|
||||
/**
|
||||
* @brief Indicates that the frame is either the complete UTF-8 encoded text frame to be sent
|
||||
* to the remote node or the final frame of a multipart message.
|
||||
*/
|
||||
Text,
|
||||
/**
|
||||
* @brief Indicates that the frame is either the complete binary frame to be sent
|
||||
* to the remote node or the final frame of a multipart message.
|
||||
*/
|
||||
Binary,
|
||||
/**
|
||||
* @brief Indicates that the frame is a partial binary frame - it is NOT the
|
||||
* complete frame to be sent to the remote node.
|
||||
*/
|
||||
BinaryFragment,
|
||||
|
||||
/**
|
||||
* @brief Indicates that the frame is a "close" frame - the remote node
|
||||
* sent a close frame.
|
||||
*/
|
||||
Closed,
|
||||
};
|
||||
|
||||
/** @brief Close information returned from a WebSocket transport that has builtin support
|
||||
* for WebSockets.
|
||||
*/
|
||||
struct NativeWebSocketCloseInformation
|
||||
{
|
||||
/**
|
||||
* @brief Close response code.
|
||||
*/
|
||||
uint16_t CloseReason;
|
||||
/**
|
||||
* @brief Close reason.
|
||||
*/
|
||||
std::string CloseReasonDescription;
|
||||
};
|
||||
/** @brief Frame information returned from a WebSocket transport that has builtin support
|
||||
* for WebSockets.
|
||||
*/
|
||||
struct NativeWebSocketReceiveInformation
|
||||
{
|
||||
/**
|
||||
* @brief Type of frame received.
|
||||
*/
|
||||
NativeWebSocketFrameType FrameType;
|
||||
/**
|
||||
* @brief Data received.
|
||||
*/
|
||||
std::vector<uint8_t> FrameData;
|
||||
};
|
||||
/**
|
||||
* @brief Destructs `%WebSocketTransport`.
|
||||
*
|
||||
*/
|
||||
virtual ~WebSocketTransport() {}
|
||||
|
||||
/**
|
||||
* @brief Indicates whether the transport natively supports WebSockets.
|
||||
*
|
||||
* @returns true if the transport has native websocket support, false otherwise.
|
||||
*/
|
||||
virtual bool HasBuiltInWebSocketSupport() = 0;
|
||||
|
||||
/**
|
||||
* @brief Closes the WebSocket.
|
||||
*
|
||||
* Does not notify the remote endpoint that the socket is being closed.
|
||||
*
|
||||
*/
|
||||
virtual void Close() = 0;
|
||||
|
||||
/**************/
|
||||
/* Native WebSocket support functions*/
|
||||
/**************/
|
||||
/**
|
||||
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
|
||||
*
|
||||
* @param status Status value to be sent to the remote node. Application defined.
|
||||
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
virtual void NativeCloseSocket(
|
||||
uint16_t status,
|
||||
std::string const& disconnectReason,
|
||||
Azure::Core::Context const& context)
|
||||
= 0;
|
||||
|
||||
/**
|
||||
* @brief Retrieve the information associated with a WebSocket close response.
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
* @returns a tuple containing the status code and string.
|
||||
*/
|
||||
virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
|
||||
Azure::Core::Context const& context)
|
||||
= 0;
|
||||
|
||||
/**
|
||||
* @brief Send a frame of data to the remote node.
|
||||
*
|
||||
* @param frameType Frame type sent to the server, Text or Binary.
|
||||
* @param frameData Frame data to be sent to the server.
|
||||
* @param context Context for the operation.
|
||||
*/
|
||||
virtual void NativeSendFrame(
|
||||
NativeWebSocketFrameType frameType,
|
||||
std::vector<uint8_t> const& frameData,
|
||||
Azure::Core::Context const& context)
|
||||
= 0;
|
||||
|
||||
/**
|
||||
* @brief Receive a frame from the remote WebSocket server.
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
* @returns a tuple containing the Frame data received from the remote server and the type of
|
||||
* data returned from the remote endpoint
|
||||
*/
|
||||
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
|
||||
Azure::Core::Context const& context)
|
||||
= 0;
|
||||
|
||||
/**************/
|
||||
/* Non Native WebSocket support functions */
|
||||
/**************/
|
||||
|
||||
/**
|
||||
* @brief This function is used when working with streams to pull more data from the wire.
|
||||
* Function will try to keep pulling data from socket until the buffer is all written or until
|
||||
* there is no more data to get from the socket.
|
||||
*
|
||||
*/
|
||||
virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) = 0;
|
||||
|
||||
/**
|
||||
* @brief This method will use the raw socket to write all the bytes from buffer.
|
||||
*
|
||||
*/
|
||||
virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) = 0;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Constructs a default instance of `%WebSocketTransport`.
|
||||
*
|
||||
*/
|
||||
WebSocketTransport() = default;
|
||||
|
||||
/**
|
||||
* @brief Constructs `%HttpTransport` by copying another instance of `%HttpTransport`.
|
||||
*
|
||||
* @param other An instance to copy.
|
||||
*/
|
||||
WebSocketTransport(const WebSocketTransport& other) = default;
|
||||
|
||||
/**
|
||||
* @brief Constructs a WebSocketTransport from another WebSocketTransport.
|
||||
*
|
||||
* @param other An instance to move in.
|
||||
*/
|
||||
WebSocketTransport(WebSocketTransport&& other) = default;
|
||||
|
||||
/**
|
||||
* @brief Assigns one WebSocketTransport to another.
|
||||
*
|
||||
* @param other An instance to assign.
|
||||
*
|
||||
* @return A reference to this instance.
|
||||
*/
|
||||
WebSocketTransport& operator=(const WebSocketTransport& other) = default;
|
||||
};
|
||||
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
@ -0,0 +1,140 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via WInHTTP.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "azure/core/context.hpp"
|
||||
#include "azure/core/http/http.hpp"
|
||||
#include "azure/core/http/transport.hpp"
|
||||
#include "azure/core/http/websockets/websockets_transport.hpp"
|
||||
#include "azure/core/http/win_http_transport.hpp"
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
|
||||
/**
|
||||
* @brief Concrete implementation of a WebSocket Transport that uses WinHTTP.
|
||||
*/
|
||||
class WinHttpWebSocketTransport : public WebSocketTransport, public WinHttpTransport {
|
||||
|
||||
Azure::Core::Http::_detail::unique_HINTERNET m_socketHandle;
|
||||
std::mutex m_sendMutex;
|
||||
std::mutex m_receiveMutex;
|
||||
|
||||
// Called by the
|
||||
void OnUpgradedConnection(
|
||||
Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) override;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new WinHTTP WebSocket Transport.
|
||||
*
|
||||
* @param options Optional parameter to override the default options.
|
||||
*/
|
||||
WinHttpWebSocketTransport(WinHttpTransportOptions const& options = WinHttpTransportOptions())
|
||||
: WinHttpTransport(options)
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse
|
||||
*
|
||||
* @param request an HTTP Request to be send.
|
||||
* @param context A context to control the request lifetime.
|
||||
*
|
||||
* @return unique ptr to an HTTP RawResponse.
|
||||
*/
|
||||
virtual std::unique_ptr<RawResponse> Send(Request& request, Context const& context) override;
|
||||
|
||||
/**
|
||||
* @brief Indicates if the transports natively websockets or not.
|
||||
*
|
||||
* @details For the WinHTTP websocket transport, the WinHTTP API supports websockets.
|
||||
*/
|
||||
virtual bool HasBuiltInWebSocketSupport() override { return true; }
|
||||
|
||||
/**
|
||||
* @brief Close the underlying WebSocket handle.
|
||||
*
|
||||
*/
|
||||
virtual void Close() override;
|
||||
|
||||
// Native WebSocket support methods.
|
||||
/**
|
||||
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
* @param status Status value to be sent to the remote node. Application defined.
|
||||
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
*/
|
||||
virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&)
|
||||
override;
|
||||
|
||||
/**
|
||||
* @brief Retrieve the information associated with a WebSocket close response.
|
||||
*
|
||||
* Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
* @returns a tuple containing the status code and string.
|
||||
*/
|
||||
virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
|
||||
Azure::Core::Context const& context) override;
|
||||
|
||||
/**
|
||||
* @brief Send a frame of data to the remote node.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native
|
||||
* websockets.
|
||||
*
|
||||
* @brief frameType Frame type sent to the server, Text or Binary.
|
||||
* @brief frameData Frame data to be sent to the server.
|
||||
*/
|
||||
virtual void NativeSendFrame(
|
||||
NativeWebSocketFrameType,
|
||||
std::vector<uint8_t> const&,
|
||||
Azure::Core::Context const&) override;
|
||||
|
||||
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
|
||||
Azure::Core::Context const&) override;
|
||||
|
||||
// Non-Native WebSocket support.
|
||||
/**
|
||||
* @brief This function is used when working with streams to pull more data from the wire.
|
||||
* Function will try to keep pulling data from socket until the buffer is all written or
|
||||
* until there is no more data to get from the socket.
|
||||
*
|
||||
* @details Not implemented for WinHTTP websockets because WinHTTP implements websockets
|
||||
* natively.
|
||||
*/
|
||||
virtual size_t ReadFromSocket(uint8_t*, size_t, Context const&) override
|
||||
{
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This method will use sockets to write all the bytes from buffer.
|
||||
*
|
||||
* @details Not implemented for WinHTTP websockets because WinHTTP implements websockets
|
||||
* natively.
|
||||
*
|
||||
*/
|
||||
virtual int SendBuffer(uint8_t const*, size_t, Context const&) override
|
||||
{
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
|
||||
bool HasWebSocketSupport() const override { return true; }
|
||||
};
|
||||
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
82
sdk/core/azure-core/src/http/curl/curl_websockets.cpp
Normal file
82
sdk/core/azure-core/src/http/curl/curl_websockets.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "azure/core/http/http.hpp"
|
||||
#include "azure/core/http/policies/policy.hpp"
|
||||
#include "azure/core/http/transport.hpp"
|
||||
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
|
||||
#include "azure/core/internal/diagnostics/log.hpp"
|
||||
#include "azure/core/platform.hpp"
|
||||
|
||||
// Private include
|
||||
#include "curl_connection_private.hpp"
|
||||
|
||||
#if defined(AZ_PLATFORM_POSIX)
|
||||
#include <poll.h> // for poll()
|
||||
#include <sys/socket.h> // for socket shutdown
|
||||
#elif defined(AZ_PLATFORM_WINDOWS)
|
||||
#if !defined(WIN32_LEAN_AND_MEAN)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#endif
|
||||
#if !defined(NOMINMAX)
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <winapifamily.h>
|
||||
#include <winsock2.h> // for WSAPoll();
|
||||
#endif
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
|
||||
void CurlWebSocketTransport::Close() { m_upgradedConnection->Shutdown(); }
|
||||
|
||||
// Send an HTTP request to the remote server.
|
||||
std::unique_ptr<RawResponse> CurlWebSocketTransport::Send(
|
||||
Request& request,
|
||||
Context const& context)
|
||||
{
|
||||
// CURL doesn't understand the ws and wss protocols, so change the URL to be http based.
|
||||
std::string requestScheme(request.GetUrl().GetScheme());
|
||||
if (requestScheme == "wss" || requestScheme == "ws")
|
||||
{
|
||||
if (requestScheme == "wss")
|
||||
{
|
||||
request.GetUrl().SetScheme("https");
|
||||
}
|
||||
else
|
||||
{
|
||||
request.GetUrl().SetScheme("http");
|
||||
}
|
||||
}
|
||||
return CurlTransport::Send(request, context);
|
||||
}
|
||||
|
||||
size_t CurlWebSocketTransport::ReadFromSocket(
|
||||
uint8_t* buffer,
|
||||
size_t bufferSize,
|
||||
Context const& context)
|
||||
{
|
||||
return m_upgradedConnection->ReadFromSocket(buffer, bufferSize, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This method will use libcurl socket to write all the bytes from buffer.
|
||||
*
|
||||
*/
|
||||
int CurlWebSocketTransport::SendBuffer(
|
||||
uint8_t const* buffer,
|
||||
size_t bufferSize,
|
||||
Context const& context)
|
||||
{
|
||||
return m_upgradedConnection->SendBuffer(buffer, bufferSize, context);
|
||||
}
|
||||
|
||||
void CurlWebSocketTransport::OnUpgradedConnection(
|
||||
std::unique_ptr<CurlNetworkConnection>&& upgradedConnection)
|
||||
{
|
||||
// Note that m_upgradedConnection is a std::shared_ptr. We define it as a std::shared_ptr
|
||||
// because a std::shared_ptr can be declared on an incomplete type, while a std::unique_ptr
|
||||
// cannot.
|
||||
m_upgradedConnection = std::move(upgradedConnection);
|
||||
}
|
||||
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
106
sdk/core/azure-core/src/http/websockets/websockets.cpp
Normal file
106
sdk/core/azure-core/src/http/websockets/websockets.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "azure/core/http/websockets/websockets.hpp"
|
||||
#include "azure/core/context.hpp"
|
||||
#include "websockets_impl.hpp"
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _internal {
|
||||
|
||||
WebSocket::WebSocket(Azure::Core::Url const& remoteUrl, WebSocketOptions const& options)
|
||||
: m_socketImplementation(
|
||||
std::make_unique<Azure::Core::Http::WebSockets::_detail::WebSocketImplementation>(
|
||||
remoteUrl,
|
||||
options))
|
||||
|
||||
{
|
||||
}
|
||||
WebSocket::~WebSocket() {}
|
||||
|
||||
void WebSocket::Open(Azure::Core::Context const& context)
|
||||
{
|
||||
m_socketImplementation->Open(context);
|
||||
}
|
||||
void WebSocket::Close(Azure::Core::Context const& context)
|
||||
{
|
||||
m_socketImplementation->Close(
|
||||
static_cast<uint16_t>(WebSocketErrorCode::EndpointDisappearing), {}, context);
|
||||
}
|
||||
void WebSocket::Close(
|
||||
uint16_t closeStatus,
|
||||
std::string const& closeReason,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
m_socketImplementation->Close(closeStatus, closeReason, context);
|
||||
}
|
||||
|
||||
void WebSocket::SendFrame(
|
||||
std::string const& textFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
m_socketImplementation->SendFrame(textFrame, isFinalFrame, context);
|
||||
}
|
||||
|
||||
void WebSocket::SendFrame(
|
||||
std::vector<uint8_t> const& binaryFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
m_socketImplementation->SendFrame(binaryFrame, isFinalFrame, context);
|
||||
}
|
||||
|
||||
WebSocketStatistics WebSocket::GetStatistics() const
|
||||
{
|
||||
return m_socketImplementation->GetStatistics();
|
||||
}
|
||||
|
||||
bool WebSocket::HasBuiltInWebSocketSupport() const
|
||||
{
|
||||
return m_socketImplementation->HasBuiltInWebSocketSupport();
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketFrame> WebSocket::ReceiveFrame(Azure::Core::Context const& context)
|
||||
{
|
||||
return m_socketImplementation->ReceiveFrame(context);
|
||||
}
|
||||
|
||||
void WebSocket::AddHeader(std::string const& headerName, std::string const& headerValue)
|
||||
{
|
||||
m_socketImplementation->AddHeader(headerName, headerValue);
|
||||
}
|
||||
std::string const& WebSocket::GetNegotiatedProtocol() const
|
||||
{
|
||||
return m_socketImplementation->GetNegotiatedProtocol();
|
||||
}
|
||||
|
||||
bool WebSocket::IsOpen() const { return m_socketImplementation->IsOpen(); }
|
||||
|
||||
std::shared_ptr<WebSocketTextFrame> WebSocketFrame::AsTextFrame()
|
||||
{
|
||||
if (FrameType != WebSocketFrameType::TextFrameReceived)
|
||||
{
|
||||
throw std::logic_error("Cannot cast to TextFrameReceived.");
|
||||
}
|
||||
return static_cast<WebSocketTextFrame*>(this)->shared_from_this();
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketBinaryFrame> WebSocketFrame::AsBinaryFrame()
|
||||
{
|
||||
if (FrameType != WebSocketFrameType::BinaryFrameReceived)
|
||||
{
|
||||
throw std::logic_error("Cannot cast to BinaryFrameReceived.");
|
||||
}
|
||||
return static_cast<WebSocketBinaryFrame*>(this)->shared_from_this();
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketPeerCloseFrame> WebSocketFrame::AsPeerCloseFrame()
|
||||
{
|
||||
if (FrameType != WebSocketFrameType::PeerClosedReceived)
|
||||
{
|
||||
throw std::logic_error("Cannot cast to PeerClose.");
|
||||
}
|
||||
return static_cast<WebSocketPeerCloseFrame*>(this)->shared_from_this();
|
||||
}
|
||||
|
||||
}}}}} // namespace Azure::Core::Http::WebSockets::_internal
|
||||
876
sdk/core/azure-core/src/http/websockets/websockets_impl.cpp
Normal file
876
sdk/core/azure-core/src/http/websockets/websockets_impl.cpp
Normal file
@ -0,0 +1,876 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "websockets_impl.hpp"
|
||||
#include "azure/core/base64.hpp"
|
||||
#include "azure/core/http/policies/policy.hpp"
|
||||
#include "azure/core/internal/cryptography/sha_hash.hpp"
|
||||
// SUPPORT_NATIVE_TRANSPORT indicates if WinHTTP should be compiled with native transport support
|
||||
// or not.
|
||||
// Note that this is primarily required to improve the code coverage numbers in the CI pipeline.
|
||||
#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER)
|
||||
#include "azure/core/http/websockets/win_http_websockets_transport.hpp"
|
||||
#define SUPPORT_NATIVE_TRANSPORT 1
|
||||
#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
|
||||
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
|
||||
#define SUPPORT_NATIVE_TRANSPORT 0
|
||||
#endif
|
||||
#include "azure/core/internal/diagnostics/log.hpp"
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <iomanip>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail {
|
||||
using namespace Azure::Core::Http::WebSockets::_internal;
|
||||
using namespace Azure::Core::Diagnostics::_internal;
|
||||
using namespace Azure::Core::Diagnostics;
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
namespace {
|
||||
std::string HexEncode(std::vector<uint8_t> const& data, size_t length)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < std::min(data.size(), length); i++)
|
||||
{
|
||||
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(data[i]);
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
WebSocketImplementation::WebSocketImplementation(
|
||||
Azure::Core::Url const& remoteUrl,
|
||||
WebSocketOptions const& options)
|
||||
: m_remoteUrl(remoteUrl), m_options(options), m_pingThread(this, m_options.PingInterval)
|
||||
{
|
||||
}
|
||||
|
||||
void WebSocketImplementation::Open(Azure::Core::Context const& context)
|
||||
{
|
||||
if (m_state != SocketState::Invalid && m_state != SocketState::Closed)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket in unexpected state: " + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
m_state = SocketState::Opening;
|
||||
|
||||
#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER)
|
||||
WinHttpTransportOptions transportOptions;
|
||||
auto winHttpTransport
|
||||
= std::make_shared<Azure::Core::Http::WebSockets::WinHttpWebSocketTransport>(
|
||||
transportOptions);
|
||||
m_transport = std::static_pointer_cast<WebSocketTransport>(winHttpTransport);
|
||||
m_options.Transport.Transport = std::static_pointer_cast<HttpTransport>(winHttpTransport);
|
||||
#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
|
||||
CurlWebSocketTransportOptions transportOptions;
|
||||
transportOptions.HttpKeepAlive = false;
|
||||
auto curlWebSockets
|
||||
= std::make_shared<Azure::Core::Http::WebSockets::CurlWebSocketTransport>(transportOptions);
|
||||
|
||||
m_transport = std::static_pointer_cast<WebSocketTransport>(curlWebSockets);
|
||||
m_options.Transport.Transport = std::static_pointer_cast<HttpTransport>(curlWebSockets);
|
||||
#endif
|
||||
|
||||
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> perCallPolicies{};
|
||||
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> perRetryPolicies{};
|
||||
// If the caller has told us a service name, add the telemetry policy to the pipeline to add
|
||||
// Create pipeline for WebSocket handshake
|
||||
std::string serviceName = m_options.ServiceName.empty() ? "azure.core.websockets" : m_options.ServiceName;
|
||||
std::string serviceVersion = m_options.ServiceVersion.empty() ? "1.0.0" : m_options.ServiceVersion;
|
||||
|
||||
Azure::Core::Http::_internal::HttpPipeline openPipeline(
|
||||
m_options, serviceName, serviceVersion, std::move(perRetryPolicies), std::move(perCallPolicies));
|
||||
|
||||
Azure::Core::Http::Request openSocketRequest(
|
||||
Azure::Core::Http::HttpMethod::Get, m_remoteUrl, false);
|
||||
|
||||
// Generate the random request key. Only used when the transport doesn't support websockets
|
||||
// natively.
|
||||
auto randomKey = GenerateRandomKey();
|
||||
auto encodedKey = Azure::Core::Convert::Base64Encode(randomKey);
|
||||
if (!m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
// If the transport doesn't support WebSockets natively, set the standardized WebSocket
|
||||
// upgrade headers.
|
||||
openSocketRequest.SetHeader("Upgrade", "websocket");
|
||||
openSocketRequest.SetHeader("Connection", "upgrade");
|
||||
openSocketRequest.SetHeader("Sec-WebSocket-Version", "13");
|
||||
openSocketRequest.SetHeader("Sec-WebSocket-Key", encodedKey);
|
||||
}
|
||||
if (!m_options.Protocols.empty())
|
||||
{
|
||||
|
||||
std::string protocols;
|
||||
for (auto const& protocol : m_options.Protocols)
|
||||
{
|
||||
protocols += protocol;
|
||||
protocols += ", ";
|
||||
}
|
||||
protocols = protocols.substr(0, protocols.size() - 2);
|
||||
openSocketRequest.SetHeader("Sec-WebSocket-Protocol", protocols);
|
||||
}
|
||||
for (auto const& additionalHeader : m_headers)
|
||||
{
|
||||
openSocketRequest.SetHeader(additionalHeader.first, additionalHeader.second);
|
||||
}
|
||||
std::string remoteOrigin;
|
||||
remoteOrigin = m_remoteUrl.GetScheme();
|
||||
remoteOrigin += "://";
|
||||
remoteOrigin += m_remoteUrl.GetHost();
|
||||
openSocketRequest.SetHeader("Origin", remoteOrigin);
|
||||
|
||||
// Send the connect request to the WebSocket server.
|
||||
auto response = openPipeline.Send(openSocketRequest, context);
|
||||
|
||||
// Ensure that the server thinks we're switching protocols. If it doesn't,
|
||||
// fail immediately.
|
||||
if (response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::SwitchingProtocols)
|
||||
{
|
||||
throw Azure::Core::Http::TransportException("Unexpected handshake response");
|
||||
}
|
||||
|
||||
// Prove that the server received this socket request.
|
||||
auto& responseHeaders = response->GetHeaders();
|
||||
if (!m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
auto socketAccept(responseHeaders.find("Sec-WebSocket-Accept"));
|
||||
if (socketAccept == responseHeaders.end())
|
||||
{
|
||||
throw Azure::Core::Http::TransportException("Missing Sec-WebSocket-Accept header");
|
||||
}
|
||||
// Verify that the WebSocket server received *this* open request.
|
||||
else
|
||||
{
|
||||
VerifySocketAccept(encodedKey, socketAccept->second);
|
||||
}
|
||||
m_initialBodyStream = response->ExtractBodyStream();
|
||||
m_pingThread.Start(m_transport);
|
||||
}
|
||||
|
||||
// Remember the protocol that the client chose.
|
||||
auto chosenProtocol = responseHeaders.find("Sec-WebSocket-Protocol");
|
||||
if (chosenProtocol != responseHeaders.end())
|
||||
{
|
||||
m_chosenProtocol = chosenProtocol->second;
|
||||
}
|
||||
|
||||
m_state = SocketState::Open;
|
||||
}
|
||||
bool WebSocketImplementation::HasBuiltInWebSocketSupport()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
if (m_state != SocketState::Open)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
return m_transport->HasBuiltInWebSocketSupport();
|
||||
}
|
||||
|
||||
std::string const& WebSocketImplementation::GetNegotiatedProtocol()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
if (m_state != SocketState::Open)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
return m_chosenProtocol;
|
||||
}
|
||||
|
||||
void WebSocketImplementation::AddHeader(std::string const& header, std::string const& headerValue)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
if (m_state != SocketState::Closed && m_state != SocketState::Invalid)
|
||||
{
|
||||
throw std::runtime_error("AddHeader can only be called on closed sockets.");
|
||||
}
|
||||
m_headers.emplace(std::make_pair(header, headerValue));
|
||||
}
|
||||
|
||||
void WebSocketImplementation::Close(
|
||||
uint16_t closeStatus,
|
||||
std::string const& closeReason,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
|
||||
// If we're closing an already closed socket, we're done.
|
||||
if (m_state == SocketState::Closed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if (m_state != SocketState::Open)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
m_state = SocketState::Closing;
|
||||
#if SUPPORT_NATIVE_TRANSPORT
|
||||
if (m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
m_transport->NativeCloseSocket(closeStatus, closeReason.c_str(), context);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// Send a going away message to the server.
|
||||
std::vector<uint8_t> closePayload;
|
||||
closePayload.push_back(closeStatus >> 8);
|
||||
closePayload.push_back(closeStatus & 0xff);
|
||||
closePayload.insert(closePayload.end(), closeReason.begin(), closeReason.end());
|
||||
std::vector<uint8_t> closeFrame = EncodeFrame(SocketOpcode::Close, true, closePayload);
|
||||
SendTransportBuffer(closeFrame, context);
|
||||
|
||||
// Unlock the state mutex before waiting for the close response to be received.
|
||||
lock.unlock();
|
||||
// Drain the incoming series of frames from the server.
|
||||
// Note that there might be in-flight frames that were sent from the other end of the
|
||||
// WebSocket that we don't care about any more (since we're closing the WebSocket). So
|
||||
// drain those frames.
|
||||
auto closeResponse = ReceiveTransportFrame(context);
|
||||
while (closeResponse && closeResponse->Opcode != SocketOpcode::Close)
|
||||
{
|
||||
m_receiveStatistics.FramesDroppedByClose++;
|
||||
Log::Write(
|
||||
Logger::Level::Warning,
|
||||
"Received unexpected frame during close. Opcode: "
|
||||
+ std::to_string(static_cast<uint8_t>(closeResponse->Opcode)));
|
||||
closeResponse = ReceiveTransportFrame(context);
|
||||
}
|
||||
|
||||
// Re-acquire the state lock once we've received the close response.
|
||||
lock.lock();
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
}
|
||||
// Close the socket - after this point, the m_transport is invalid.
|
||||
m_pingThread.Shutdown();
|
||||
m_transport->Close();
|
||||
m_state = SocketState::Closed;
|
||||
}
|
||||
|
||||
void WebSocketImplementation::SendFrame(
|
||||
std::string const& textFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
if (m_state != SocketState::Open)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
std::vector<uint8_t> utf8text(textFrame.begin(), textFrame.end());
|
||||
m_receiveStatistics.TextFramesSent++;
|
||||
#if SUPPORT_NATIVE_TRANSPORT
|
||||
if (m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
m_transport->NativeSendFrame(
|
||||
(isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Text
|
||||
: WebSocketTransport::NativeWebSocketFrameType::TextFragment),
|
||||
utf8text,
|
||||
context);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
std::vector<uint8_t> sendFrame = EncodeFrame(SocketOpcode::TextFrame, isFinalFrame, utf8text);
|
||||
SendTransportBuffer(sendFrame, context);
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocketImplementation::SendFrame(
|
||||
std::vector<uint8_t> const& binaryFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
|
||||
if (m_state != SocketState::Open)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
m_receiveStatistics.BinaryFramesSent++;
|
||||
#if SUPPORT_NATIVE_TRANSPORT
|
||||
if (m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
m_transport->NativeSendFrame(
|
||||
(isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Binary
|
||||
: WebSocketTransport::NativeWebSocketFrameType::BinaryFragment),
|
||||
binaryFrame,
|
||||
context);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
std::vector<uint8_t> sendFrame
|
||||
= EncodeFrame(SocketOpcode::BinaryFrame, isFinalFrame, binaryFrame);
|
||||
|
||||
SendTransportBuffer(sendFrame, context);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketFrame> WebSocketImplementation::ReceiveFrame(
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_stateMutex);
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
|
||||
if (m_state != SocketState::Open && m_state != SocketState::Closing)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
|
||||
}
|
||||
|
||||
// Unlock the state lock to allow other threads to run. If we don't, we might end up in in a
|
||||
// situation where the server won't respond to the this client because all the client threads
|
||||
// are blocked on the state lock.
|
||||
lock.unlock();
|
||||
|
||||
std::shared_ptr<WebSocketInternalFrame> frame;
|
||||
// Loop until we receive an returnable incoming frame.
|
||||
// If the incoming frame is returnable, we return the value from the frame.
|
||||
while (true)
|
||||
{
|
||||
frame = ReceiveTransportFrame(context);
|
||||
if (frame)
|
||||
{
|
||||
switch (frame->Opcode)
|
||||
{
|
||||
// When we receive a "ping" frame, we want to send a Pong frame back to the server.
|
||||
case SocketOpcode::Ping:
|
||||
Log::Write(
|
||||
Logger::Level::Verbose, "Received Ping frame: " + HexEncode(frame->Payload, 16));
|
||||
SendPong(frame->Payload, context);
|
||||
break;
|
||||
|
||||
// We want to ignore all incoming "Pong" frames.
|
||||
case SocketOpcode::Pong:
|
||||
Log::Write(
|
||||
Logger::Level::Verbose, "Received Pong frame: " + HexEncode(frame->Payload, 16));
|
||||
break;
|
||||
|
||||
case SocketOpcode::BinaryFrame:
|
||||
m_currentMessageType = SocketMessageType::Binary;
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketBinaryFrame(
|
||||
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
|
||||
|
||||
case SocketOpcode::TextFrame:
|
||||
m_currentMessageType = SocketMessageType::Text;
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketTextFrame(
|
||||
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
|
||||
|
||||
case SocketOpcode::Close: {
|
||||
if (frame->Payload.size() < 2)
|
||||
{
|
||||
throw std::runtime_error("Close response buffer is too short.");
|
||||
}
|
||||
// Encode the payload for close according to RFC 6455
|
||||
// section 5.5.1. The first two bytes of the payload contain the status code.
|
||||
// The remainder of the payload is a UTF-8 encoded string.
|
||||
uint16_t errorCode = 0;
|
||||
errorCode |= (frame->Payload[0] << 8) & 0xff00;
|
||||
errorCode |= (frame->Payload[1] & 0x00ff);
|
||||
|
||||
// We received a close frame, mark the socket as closed. Make sure we
|
||||
// reacquire the state lock before setting the state to closed.
|
||||
lock.lock();
|
||||
m_stateOwner = std::this_thread::get_id();
|
||||
m_state = SocketState::Closed;
|
||||
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketPeerCloseFrame(
|
||||
errorCode, std::string(frame->Payload.begin() + 2, frame->Payload.end())));
|
||||
}
|
||||
|
||||
// Continuation frames need to be treated somewhat specially.
|
||||
// We depend on the fact that the protocol requires that a Continuation frame
|
||||
// only be sent if it is part of a multi-frame message whose previous frame was a Text
|
||||
// or Binary frame.
|
||||
case SocketOpcode::Continuation:
|
||||
if (m_currentMessageType == SocketMessageType::Text)
|
||||
{
|
||||
if (frame->IsFinalFrame)
|
||||
{
|
||||
m_currentMessageType = SocketMessageType::Unknown;
|
||||
}
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketTextFrame(
|
||||
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
|
||||
}
|
||||
else if (m_currentMessageType == SocketMessageType::Binary)
|
||||
{
|
||||
if (frame->IsFinalFrame)
|
||||
{
|
||||
m_currentMessageType = SocketMessageType::Unknown;
|
||||
}
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketBinaryFrame(
|
||||
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
m_receiveStatistics.FramesDroppedByProtocolError++;
|
||||
throw std::runtime_error("Unknown message type and received continuation opcode");
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown frame type received.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (m_state != SocketState::Closed && m_state != SocketState::Closing)
|
||||
{
|
||||
throw std::runtime_error("Transport is at EOF, no frame to receive.");
|
||||
}
|
||||
// The socket was closed, most likely locally, so fake a close frame response.
|
||||
return std::shared_ptr<WebSocketFrame>(new WebSocketPeerCloseFrame());
|
||||
}
|
||||
|
||||
context.ThrowIfCancelled();
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<WebSocketImplementation::WebSocketInternalFrame>
|
||||
WebSocketImplementation::ReceiveTransportFrame(Azure::Core::Context const& context)
|
||||
{
|
||||
#if SUPPORT_NATIVE_TRANSPORT
|
||||
if (m_transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
auto payload = m_transport->NativeReceiveFrame(context);
|
||||
m_receiveStatistics.FramesReceived++;
|
||||
switch (payload.FrameType)
|
||||
{
|
||||
case WebSocketTransport::NativeWebSocketFrameType::Binary:
|
||||
m_receiveStatistics.BinaryFramesReceived++;
|
||||
return std::make_shared<WebSocketInternalFrame>(
|
||||
SocketOpcode::BinaryFrame, true, payload.FrameData);
|
||||
case WebSocketTransport::NativeWebSocketFrameType::BinaryFragment:
|
||||
m_receiveStatistics.BinaryFramesReceived++;
|
||||
return std::make_shared<WebSocketInternalFrame>(
|
||||
SocketOpcode::BinaryFrame, false, payload.FrameData);
|
||||
case WebSocketTransport::NativeWebSocketFrameType::Text:
|
||||
m_receiveStatistics.TextFramesReceived++;
|
||||
return std::make_shared<WebSocketInternalFrame>(
|
||||
SocketOpcode::TextFrame, true, payload.FrameData);
|
||||
case WebSocketTransport::NativeWebSocketFrameType::TextFragment:
|
||||
m_receiveStatistics.TextFramesReceived++;
|
||||
return std::make_shared<WebSocketInternalFrame>(
|
||||
SocketOpcode::TextFrame, false, payload.FrameData);
|
||||
case WebSocketTransport::NativeWebSocketFrameType::Closed: {
|
||||
m_receiveStatistics.CloseFramesReceived++;
|
||||
auto closeResult = m_transport->NativeGetCloseSocketInformation(context);
|
||||
std::vector<uint8_t> closePayload;
|
||||
closePayload.push_back(closeResult.CloseReason >> 8);
|
||||
closePayload.push_back(closeResult.CloseReason & 0xff);
|
||||
closePayload.insert(
|
||||
closePayload.end(),
|
||||
closeResult.CloseReasonDescription.begin(),
|
||||
closeResult.CloseReasonDescription.end());
|
||||
return std::make_shared<WebSocketInternalFrame>(SocketOpcode::Close, true, closePayload);
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unexpected frame type received.");
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
std::shared_ptr<WebSocketInternalFrame> frame = DecodeFrame(context);
|
||||
if (frame)
|
||||
{
|
||||
|
||||
// Handle statistics for the incoming frame.
|
||||
m_receiveStatistics.FramesReceived++;
|
||||
switch (frame->Opcode)
|
||||
{
|
||||
case SocketOpcode::Ping: {
|
||||
m_receiveStatistics.PingFramesReceived++;
|
||||
break;
|
||||
}
|
||||
case SocketOpcode::Pong: {
|
||||
m_receiveStatistics.PongFramesReceived++;
|
||||
break;
|
||||
}
|
||||
case SocketOpcode::TextFrame: {
|
||||
m_receiveStatistics.TextFramesReceived++;
|
||||
break;
|
||||
}
|
||||
case SocketOpcode::BinaryFrame: {
|
||||
m_receiveStatistics.BinaryFramesReceived++;
|
||||
break;
|
||||
}
|
||||
case SocketOpcode::Close: {
|
||||
m_receiveStatistics.CloseFramesReceived++;
|
||||
break;
|
||||
}
|
||||
case SocketOpcode::Continuation: {
|
||||
m_receiveStatistics.ContinuationFramesReceived++;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
m_receiveStatistics.UnknownFramesReceived++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
m_receiveStatistics.FramesDropped++;
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
}
|
||||
|
||||
WebSocketStatistics WebSocketImplementation::GetStatistics() const
|
||||
{
|
||||
WebSocketStatistics returnValue{};
|
||||
returnValue.FramesSent = m_receiveStatistics.FramesSent.load();
|
||||
returnValue.FramesReceived = m_receiveStatistics.FramesReceived.load();
|
||||
returnValue.BinaryFramesReceived = m_receiveStatistics.BinaryFramesReceived.load();
|
||||
returnValue.TextFramesReceived = m_receiveStatistics.TextFramesReceived.load();
|
||||
returnValue.BinaryFramesSent = m_receiveStatistics.BinaryFramesSent.load();
|
||||
returnValue.TextFramesSent = m_receiveStatistics.TextFramesSent.load();
|
||||
returnValue.PingFramesReceived = m_receiveStatistics.PingFramesReceived.load();
|
||||
returnValue.PongFramesReceived = m_receiveStatistics.PongFramesReceived.load();
|
||||
returnValue.PingFramesSent = m_receiveStatistics.PingFramesSent.load();
|
||||
returnValue.PongFramesSent = m_receiveStatistics.PongFramesSent.load();
|
||||
|
||||
returnValue.BytesSent = m_receiveStatistics.BytesSent.load();
|
||||
returnValue.BytesReceived = m_receiveStatistics.BytesReceived.load();
|
||||
returnValue.FramesDropped = m_receiveStatistics.FramesDropped.load();
|
||||
returnValue.FramesDroppedByClose = m_receiveStatistics.FramesDroppedByClose.load();
|
||||
returnValue.FramesDroppedByPayloadSizeLimit
|
||||
= m_receiveStatistics.FramesDroppedByPayloadSizeLimit.load();
|
||||
returnValue.FramesDroppedByProtocolError
|
||||
= m_receiveStatistics.FramesDroppedByProtocolError.load();
|
||||
returnValue.TransportReadBytes = m_receiveStatistics.TransportReadBytes.load();
|
||||
returnValue.TransportReads = m_receiveStatistics.TransportReads.load();
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> WebSocketImplementation::EncodeFrame(
|
||||
SocketOpcode opcode,
|
||||
bool isFinal,
|
||||
std::vector<uint8_t> const& payload)
|
||||
{
|
||||
std::vector<uint8_t> encodedFrame;
|
||||
// Add opcode+fin.
|
||||
encodedFrame.push_back(static_cast<uint8_t>(opcode) | (isFinal ? 0x80 : 0));
|
||||
uint8_t maskAndLength = 0;
|
||||
maskAndLength |= 0x80;
|
||||
|
||||
// Payloads smaller than 125 bytes are encoded directly in the maskAndLength field.
|
||||
uint64_t payloadSize = static_cast<uint64_t>(payload.size());
|
||||
if (payloadSize <= 125)
|
||||
{
|
||||
maskAndLength |= static_cast<uint8_t>(payload.size());
|
||||
}
|
||||
else if (payloadSize <= 65535)
|
||||
{
|
||||
// Payloads greater than 125 whose size can fit in a 16 bit integer bytes
|
||||
// are encoded as a 16 bit unsigned integer in network byte order.
|
||||
maskAndLength |= 126;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Payloads greater than 65536 have their length are encoded as a 64 bit unsigned integer
|
||||
// in network byte order.
|
||||
maskAndLength |= 127;
|
||||
}
|
||||
encodedFrame.push_back(maskAndLength);
|
||||
// Encode a 16 bit length.
|
||||
if (payloadSize > 125 && payloadSize <= 65535)
|
||||
{
|
||||
encodedFrame.push_back(static_cast<uint16_t>(payload.size()) >> 8);
|
||||
encodedFrame.push_back(static_cast<uint16_t>(payload.size()) & 0xff);
|
||||
}
|
||||
// Encode a 64 bit length.
|
||||
else if (payloadSize >= 65536)
|
||||
{
|
||||
|
||||
encodedFrame.push_back((payloadSize >> 56) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 48) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 40) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 32) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 24) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 16) & 0xff);
|
||||
encodedFrame.push_back((payloadSize >> 8) & 0xff);
|
||||
encodedFrame.push_back(payloadSize & 0xff);
|
||||
}
|
||||
// Calculate the masking key. This MUST be 4 bytes of high entropy random numbers used to
|
||||
// mask the input data.
|
||||
{
|
||||
// Start by generating the mask - 4 bytes of random data.
|
||||
std::vector<uint8_t> mask = GenerateRandomBytes(4);
|
||||
|
||||
// Append the mask to the payload.
|
||||
encodedFrame.insert(encodedFrame.end(), mask.begin(), mask.end());
|
||||
|
||||
// And mask the payload before transmitting it.
|
||||
size_t index = 0;
|
||||
for (auto ch : payload)
|
||||
{
|
||||
encodedFrame.push_back(ch ^ mask[index % 4]);
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return encodedFrame;
|
||||
}
|
||||
std::shared_ptr<WebSocketImplementation::WebSocketInternalFrame>
|
||||
WebSocketImplementation::DecodeFrame(Azure::Core::Context const& context)
|
||||
{
|
||||
// Ensure single threaded access to receive this frame.
|
||||
std::unique_lock<std::mutex> lock(m_transportMutex);
|
||||
if (IsTransportEof())
|
||||
{
|
||||
throw std::runtime_error("Frame buffer is too small.");
|
||||
}
|
||||
uint8_t payloadByte = ReadTransportByte(context);
|
||||
// If the transport is at EOF, then there is no payload data, so just return null.
|
||||
if (IsTransportEof())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
SocketOpcode opcode = static_cast<SocketOpcode>(payloadByte & 0x7f);
|
||||
bool isFinal = (payloadByte & 0x80) != 0;
|
||||
payloadByte = ReadTransportByte(context);
|
||||
if (IsTransportEof())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
if (payloadByte & 0x80)
|
||||
{
|
||||
throw std::runtime_error("Server sent a frame with a reserved bit set.");
|
||||
}
|
||||
int64_t payloadLength = payloadByte & 0x7f;
|
||||
if (payloadLength <= 125)
|
||||
{
|
||||
payloadByte += 1;
|
||||
}
|
||||
else if (payloadLength == 126)
|
||||
{
|
||||
payloadLength = ReadTransportShort(context);
|
||||
}
|
||||
else if (payloadLength == 127)
|
||||
{
|
||||
payloadLength = ReadTransportInt64(context);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::logic_error("Unexpected payload length.");
|
||||
}
|
||||
if (IsTransportEof())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> payload(ReadTransportBytes(static_cast<size_t>(payloadLength), context));
|
||||
if (IsTransportEof())
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<WebSocketInternalFrame>(opcode, isFinal, payload);
|
||||
}
|
||||
|
||||
uint8_t WebSocketImplementation::ReadTransportByte(Azure::Core::Context const& context)
|
||||
{
|
||||
if (m_bufferPos >= m_bufferLen)
|
||||
{
|
||||
// Start by reading data from our initial body stream.
|
||||
m_bufferLen = m_initialBodyStream->ReadToCount(m_buffer, m_bufferSize, context);
|
||||
if (m_bufferLen == 0)
|
||||
{
|
||||
// If we run out of the initial stream, we need to read from the transport.
|
||||
m_bufferLen = m_transport->ReadFromSocket(m_buffer, m_bufferSize, context);
|
||||
m_receiveStatistics.TransportReads++;
|
||||
m_receiveStatistics.TransportReadBytes += static_cast<uint32_t>(m_bufferLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
Azure::Core::Diagnostics::_internal::Log::Write(
|
||||
Azure::Core::Diagnostics::Logger::Level::Informational,
|
||||
"Read data from initial stream");
|
||||
}
|
||||
m_bufferPos = 0;
|
||||
if (m_bufferLen == 0)
|
||||
{
|
||||
m_eof = true;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
m_receiveStatistics.BytesReceived++;
|
||||
return m_buffer[m_bufferPos++];
|
||||
}
|
||||
uint16_t WebSocketImplementation::ReadTransportShort(Azure::Core::Context const& context)
|
||||
{
|
||||
uint16_t result = ReadTransportByte(context);
|
||||
result <<= 8;
|
||||
result |= ReadTransportByte(context);
|
||||
return result;
|
||||
}
|
||||
uint64_t WebSocketImplementation::ReadTransportInt64(Azure::Core::Context const& context)
|
||||
{
|
||||
uint64_t result = 0;
|
||||
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 56 & 0xff00000000000000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 48 & 0x00ff000000000000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 40 & 0x0000ff0000000000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 32 & 0x000000ff00000000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 24 & 0x00000000ff000000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 16 & 0x0000000000ff0000);
|
||||
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 8 & 0x000000000000ff00);
|
||||
result |= static_cast<uint64_t>(ReadTransportByte(context));
|
||||
return result;
|
||||
}
|
||||
std::vector<uint8_t> WebSocketImplementation::ReadTransportBytes(
|
||||
size_t readLength,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::vector<uint8_t> result;
|
||||
size_t index = 0;
|
||||
while (index < readLength)
|
||||
{
|
||||
uint8_t byte = ReadTransportByte(context);
|
||||
result.push_back(byte);
|
||||
index += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void WebSocketImplementation::SendTransportBuffer(
|
||||
std::vector<uint8_t> const& sendFrame,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::unique_lock<std::mutex> transportLock(m_transportMutex);
|
||||
m_receiveStatistics.BytesSent += static_cast<uint32_t>(sendFrame.size());
|
||||
m_receiveStatistics.FramesSent += 1;
|
||||
m_transport->SendBuffer(sendFrame.data(), sendFrame.size(), context);
|
||||
}
|
||||
|
||||
// Verify the Sec-WebSocket-Accept header as defined in RFC 6455 Section 1.3, which defines
|
||||
// the opening handshake used for establishing the WebSocket connection.
|
||||
std::string acceptHeaderGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
void WebSocketImplementation::VerifySocketAccept(
|
||||
std::string const& encodedKey,
|
||||
std::string const& acceptHeader)
|
||||
{
|
||||
std::string concatenatedKey(encodedKey);
|
||||
concatenatedKey += acceptHeaderGuid;
|
||||
Azure::Core::Cryptography::_internal::Sha1Hash sha1hash;
|
||||
|
||||
sha1hash.Append(
|
||||
reinterpret_cast<const uint8_t*>(concatenatedKey.data()), concatenatedKey.size());
|
||||
auto keyHash = sha1hash.Final();
|
||||
std::string encodedHash = Azure::Core::Convert::Base64Encode(keyHash);
|
||||
if (encodedHash != acceptHeader)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Hash returned by WebSocket server does not match expected hash. Aborting");
|
||||
}
|
||||
}
|
||||
|
||||
WebSocketImplementation::PingThread::PingThread(
|
||||
WebSocketImplementation* socketImplementation,
|
||||
std::chrono::duration<int64_t> pingInterval)
|
||||
: m_webSocketImplementation(socketImplementation), m_pingInterval(pingInterval)
|
||||
{
|
||||
}
|
||||
void WebSocketImplementation::PingThread::Start(std::shared_ptr<WebSocketTransport> transport)
|
||||
{
|
||||
m_stop = false;
|
||||
// Spin up a thread to receive data from the transport.
|
||||
if (!transport->HasBuiltInWebSocketSupport())
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_pingThreadStarted);
|
||||
m_pingThread = std::thread{&PingThread::PingThreadLoop, this};
|
||||
m_pingThreadReady.wait(lock);
|
||||
}
|
||||
}
|
||||
|
||||
WebSocketImplementation::PingThread::~PingThread()
|
||||
{
|
||||
// Ensure that the receive thread is stopped.
|
||||
Shutdown();
|
||||
}
|
||||
void WebSocketImplementation::PingThread::Shutdown()
|
||||
{
|
||||
if (m_pingThread.joinable())
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_stopMutex);
|
||||
m_stop = true;
|
||||
lock.unlock();
|
||||
m_pingThreadStopped.notify_all();
|
||||
|
||||
m_pingThread.join();
|
||||
}
|
||||
}
|
||||
|
||||
void WebSocketImplementation::PingThread::PingThreadLoop()
|
||||
{
|
||||
Log::Write(Logger::Level::Verbose, "Start Ping Thread Loop.");
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_pingThreadStarted);
|
||||
m_pingThreadReady.notify_all();
|
||||
}
|
||||
while (true)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m_stopMutex);
|
||||
if (this->m_pingThreadStopped.wait_for(lock, m_pingInterval) == std::cv_status::timeout)
|
||||
{
|
||||
Log::Write(Logger::Level::Verbose, "Send Ping to peer.");
|
||||
|
||||
// The receiveContext timed out, this means we timed out our "ping" timeout.
|
||||
// Send a "Ping" request to the remote node.
|
||||
auto pingData = GenerateRandomBytes(4);
|
||||
SendPing(pingData, Azure::Core::Context{});
|
||||
}
|
||||
if (m_stop)
|
||||
{
|
||||
Log::Write(Logger::Level::Verbose, "Exiting ping thread");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool WebSocketImplementation::PingThread::SendPing(
|
||||
std::vector<uint8_t> const& pingData,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::vector<uint8_t> pingFrame = EncodeFrame(SocketOpcode::Ping, true, pingData);
|
||||
m_webSocketImplementation->m_receiveStatistics.PingFramesSent++;
|
||||
m_webSocketImplementation->SendTransportBuffer(pingFrame, context);
|
||||
return true;
|
||||
}
|
||||
|
||||
void WebSocketImplementation::SendPong(
|
||||
std::vector<uint8_t> const& pongData,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
std::vector<uint8_t> pongFrame = EncodeFrame(SocketOpcode::Pong, true, pongData);
|
||||
|
||||
m_receiveStatistics.PongFramesSent++;
|
||||
SendTransportBuffer(pongFrame, context);
|
||||
}
|
||||
|
||||
// Generator for random bytes. Used in WebSocketImplementation and tests.
|
||||
std::vector<uint8_t> GenerateRandomBytes(size_t vectorSize)
|
||||
{
|
||||
std::random_device randomEngine;
|
||||
|
||||
std::vector<uint8_t> rv(vectorSize);
|
||||
std::generate(begin(rv), end(rv), [&randomEngine]() mutable {
|
||||
return static_cast<uint8_t>(randomEngine() % UINT8_MAX);
|
||||
});
|
||||
return rv;
|
||||
}
|
||||
}}}}} // namespace Azure::Core::Http::WebSockets::_detail
|
||||
373
sdk/core/azure-core/src/http/websockets/websockets_impl.hpp
Normal file
373
sdk/core/azure-core/src/http/websockets/websockets_impl.hpp
Normal file
@ -0,0 +1,373 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "azure/core/http/websockets/websockets.hpp"
|
||||
#include "azure/core/http/websockets/websockets_transport.hpp"
|
||||
#include "azure/core/internal/diagnostics/log.hpp"
|
||||
#include "azure/core/internal/http/pipeline.hpp"
|
||||
#include <array>
|
||||
#include <condition_variable>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <shared_mutex>
|
||||
#include <thread>
|
||||
|
||||
// Implementation of WebSocket protocol.
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail {
|
||||
|
||||
// Generator for random bytes. Used in WebSocketImplementation and tests.
|
||||
std::vector<uint8_t> GenerateRandomBytes(size_t vectorSize);
|
||||
|
||||
class WebSocketImplementation {
|
||||
enum class SocketState
|
||||
{
|
||||
Invalid,
|
||||
Closed,
|
||||
Opening,
|
||||
Open,
|
||||
Closing,
|
||||
};
|
||||
|
||||
public:
|
||||
WebSocketImplementation(
|
||||
Azure::Core::Url const& remoteUrl,
|
||||
_internal::WebSocketOptions const& options);
|
||||
|
||||
void Open(Azure::Core::Context const& context);
|
||||
void Close(
|
||||
uint16_t closeStatus,
|
||||
std::string const& closeReason,
|
||||
Azure::Core::Context const& context);
|
||||
void SendFrame(
|
||||
std::string const& textFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context);
|
||||
void SendFrame(
|
||||
std::vector<uint8_t> const& binaryFrame,
|
||||
bool isFinalFrame,
|
||||
Azure::Core::Context const& context);
|
||||
|
||||
std::shared_ptr<_internal::WebSocketFrame> ReceiveFrame(Azure::Core::Context const& context);
|
||||
|
||||
void AddHeader(std::string const& headerName, std::string const& headerValue);
|
||||
|
||||
std::string const& GetNegotiatedProtocol();
|
||||
bool IsOpen() { return m_state == SocketState::Open; }
|
||||
bool HasBuiltInWebSocketSupport();
|
||||
|
||||
_internal::WebSocketStatistics GetStatistics() const;
|
||||
|
||||
private:
|
||||
// WebSocket opcodes.
|
||||
enum class SocketOpcode : uint8_t
|
||||
{
|
||||
Continuation = 0x00,
|
||||
TextFrame = 0x01,
|
||||
BinaryFrame = 0x02,
|
||||
Close = 0x08,
|
||||
Ping = 0x09,
|
||||
Pong = 0x0a
|
||||
};
|
||||
|
||||
/**
|
||||
* Indicates the type of the message currently being processed. Used when processing
|
||||
* Continuation Opcode frames.
|
||||
*/
|
||||
enum class SocketMessageType : int
|
||||
{
|
||||
Unknown,
|
||||
Text,
|
||||
Binary,
|
||||
};
|
||||
|
||||
class WebSocketInternalFrame {
|
||||
public:
|
||||
SocketOpcode Opcode{};
|
||||
bool IsFinalFrame{false};
|
||||
std::vector<uint8_t> Payload;
|
||||
std::exception_ptr Exception;
|
||||
WebSocketInternalFrame(
|
||||
SocketOpcode opcode,
|
||||
bool isFinalFrame,
|
||||
std::vector<uint8_t> const& payload)
|
||||
: Opcode(opcode), IsFinalFrame(isFinalFrame), Payload(payload)
|
||||
{
|
||||
}
|
||||
WebSocketInternalFrame(std::exception_ptr exception) : Exception(exception) {}
|
||||
};
|
||||
|
||||
struct ReceiveStatistics
|
||||
{
|
||||
std::atomic<uint32_t> FramesSent;
|
||||
std::atomic<uint32_t> FramesReceived;
|
||||
std::atomic<uint32_t> BytesSent;
|
||||
std::atomic<uint32_t> BytesReceived;
|
||||
std::atomic<uint32_t> PingFramesSent;
|
||||
std::atomic<uint32_t> PingFramesReceived;
|
||||
std::atomic<uint32_t> PongFramesSent;
|
||||
std::atomic<uint32_t> PongFramesReceived;
|
||||
std::atomic<uint32_t> TextFramesReceived;
|
||||
std::atomic<uint32_t> BinaryFramesReceived;
|
||||
std::atomic<uint32_t> ContinuationFramesReceived;
|
||||
std::atomic<uint32_t> CloseFramesReceived;
|
||||
std::atomic<uint32_t> UnknownFramesReceived;
|
||||
std::atomic<uint32_t> FramesDropped;
|
||||
std::atomic<uint32_t> FramesDroppedByPayloadSizeLimit;
|
||||
std::atomic<uint32_t> FramesDroppedByProtocolError;
|
||||
std::atomic<uint32_t> TransportReads;
|
||||
std::atomic<uint32_t> TransportReadBytes;
|
||||
std::atomic<uint32_t> BinaryFramesSent;
|
||||
std::atomic<uint32_t> TextFramesSent;
|
||||
std::atomic<uint32_t> FramesDroppedByClose;
|
||||
|
||||
void Reset()
|
||||
{
|
||||
FramesSent = 0;
|
||||
BytesSent = 0;
|
||||
FramesReceived = 0;
|
||||
BytesReceived = 0;
|
||||
PingFramesReceived = 0;
|
||||
PingFramesSent = 0;
|
||||
PongFramesReceived = 0;
|
||||
PongFramesSent = 0;
|
||||
TextFramesReceived = 0;
|
||||
TextFramesSent = 0;
|
||||
BinaryFramesReceived = 0;
|
||||
BinaryFramesSent = 0;
|
||||
ContinuationFramesReceived = 0;
|
||||
CloseFramesReceived = 0;
|
||||
UnknownFramesReceived = 0;
|
||||
FramesDropped = 0;
|
||||
FramesDroppedByClose = 0;
|
||||
FramesDroppedByPayloadSizeLimit = 0;
|
||||
FramesDroppedByProtocolError = 0;
|
||||
TransportReads = 0;
|
||||
TransportReadBytes = 0;
|
||||
}
|
||||
};
|
||||
/**
|
||||
* @brief The PingThread handles sending Ping operations from the WebSocket server.
|
||||
*
|
||||
*/
|
||||
class PingThread {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new ReceiveQueue object.
|
||||
*
|
||||
* @param webSocketImplementation Parent object, used to send Ping threads.
|
||||
* @param pingInterval Interval to wait between sending pings.
|
||||
*/
|
||||
PingThread(
|
||||
WebSocketImplementation* webSocketImplementation,
|
||||
std::chrono::duration<int64_t> pingInterval);
|
||||
/**
|
||||
* @brief Destroys a ReceiveQueue object. Blocks until the queue thread is completed.
|
||||
*/
|
||||
~PingThread();
|
||||
|
||||
/**
|
||||
* @brief Start the receive queue. This will start a thread that will process incoming frames.
|
||||
*
|
||||
* @param transport The websocket transport to use for receiving frames.
|
||||
*/
|
||||
void Start(std::shared_ptr<WebSocketTransport> transport);
|
||||
/**
|
||||
* @brief Stop the receive queue. This will stop the thread that processes incoming frames.
|
||||
*/
|
||||
void Shutdown();
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief The receive queue thread.
|
||||
*/
|
||||
void PingThreadLoop();
|
||||
/**
|
||||
* @brief Send a "ping" frame to the other side of the WebSocket.
|
||||
*
|
||||
* @returns True if the ping was sent, false if the underlying transport didn't support "Ping"
|
||||
* operations.
|
||||
*/
|
||||
bool SendPing(std::vector<uint8_t> const& pingData, Azure::Core::Context const& context);
|
||||
|
||||
WebSocketImplementation* m_webSocketImplementation;
|
||||
std::chrono::duration<int64_t> m_pingInterval;
|
||||
std::thread m_pingThread;
|
||||
std::mutex m_pingThreadStarted;
|
||||
std::condition_variable m_pingThreadReady;
|
||||
|
||||
std::mutex m_stopMutex;
|
||||
std::condition_variable m_pingThreadStopped;
|
||||
bool m_stop = false;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Encode a websocket frame according to RFC 6455 section 5.2.
|
||||
*
|
||||
* This wire format for the data transfer part is described by the ABNF
|
||||
* [RFC5234] given in detail in this section. (Note that, unlike in
|
||||
* other sections of this document, the ABNF in this section is
|
||||
* operating on groups of bits. The length of each group of bits is
|
||||
* indicated in a comment. When encoded on the wire, the most
|
||||
* significant bit is the leftmost in the ABNF). A high-level overview
|
||||
* of the framing is given in the following figure. In a case of
|
||||
* conflict between the figure below and the ABNF specified later in
|
||||
* this section, the figure is authoritative.
|
||||
*
|
||||
* 0 1 2 3
|
||||
* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
* +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||
* |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||
* |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||
* |N|V|V|V| |S| | (if payload len==126/127) |
|
||||
* | |1|2|3| |K| | |
|
||||
* +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||
* | Extended payload length continued, if payload len == 127 |
|
||||
* + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
* | |Masking-key, if MASK set to 1 |
|
||||
* +-------------------------------+-------------------------------+
|
||||
* | Masking-key (continued) | Payload Data |
|
||||
* +-------------------------------- - - - - - - - - - - - - - - - +
|
||||
* : Payload Data continued ... :
|
||||
* + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||
* | Payload Data continued ... |
|
||||
* +---------------------------------------------------------------+
|
||||
*
|
||||
* FIN: 1 bit
|
||||
*
|
||||
* Indicates that this is the final fragment in a message. The first
|
||||
* fragment MAY also be the final fragment.
|
||||
*
|
||||
* RSV1, RSV2, RSV3: 1 bit each
|
||||
*
|
||||
* MUST be 0 unless an extension is negotiated that defines meanings
|
||||
* for non-zero values. If a nonzero value is received and none of
|
||||
* the negotiated extensions defines the meaning of such a nonzero
|
||||
* value, the receiving endpoint MUST _Fail the WebSocket
|
||||
* Connection_.
|
||||
*
|
||||
* Opcode: 4 bits
|
||||
*
|
||||
* Defines the interpretation of the "Payload data". If an unknown
|
||||
* opcode is received, the receiving endpoint MUST _Fail the
|
||||
* WebSocket Connection_. The following values are defined.
|
||||
*
|
||||
* * %x0 denotes a continuation frame
|
||||
*
|
||||
* * %x1 denotes a text frame
|
||||
*
|
||||
* * %x2 denotes a binary frame
|
||||
*
|
||||
* * %x3-7 are reserved for further non-control frames
|
||||
*
|
||||
* * %x8 denotes a connection close
|
||||
*
|
||||
* * %x9 denotes a ping
|
||||
*
|
||||
* * %xA denotes a pong
|
||||
*
|
||||
* * %xB-F are reserved for further control frames
|
||||
*
|
||||
* Mask: 1 bit
|
||||
*
|
||||
* Defines whether the "Payload data" is masked. If set to 1, a
|
||||
* masking key is present in masking-key, and this is used to unmask
|
||||
* the "Payload data" as per Section 5.3. All frames sent from
|
||||
* client to server have this bit set to 1.
|
||||
*
|
||||
* Payload length: 7 bits, 7+16 bits, or 7+64 bits
|
||||
*
|
||||
* The length of the "Payload data", in bytes: if 0-125, that is the
|
||||
* payload length. If 126, the following 2 bytes interpreted as a
|
||||
* 16-bit unsigned integer are the payload length. If 127, the
|
||||
* following 8 bytes interpreted as a 64-bit unsigned integer (the
|
||||
* most significant bit MUST be 0) are the payload length. Multibyte
|
||||
* length quantities are expressed in network byte order. Note that
|
||||
* in all cases, the minimal number of bytes MUST be used to encode
|
||||
* the length, for example, the length of a 124-byte-long string
|
||||
* can't be encoded as the sequence 126, 0, 124. The payload length
|
||||
* is the length of the "Extension data" + the length of the
|
||||
* "Application data". The length of the "Extension data" may be
|
||||
* zero, in which case the payload length is the length of the
|
||||
* "Application data".
|
||||
* Masking-key: 0 or 4 bytes
|
||||
*
|
||||
* All frames sent from the client to the server are masked by a
|
||||
* 32-bit value that is contained within the frame. This field is
|
||||
* present if the mask bit is set to 1 and is absent if the mask bit
|
||||
* is set to 0. See Section 5.3 for further information on client-
|
||||
* to-server masking.
|
||||
*
|
||||
* Payload data: (x+y) bytes
|
||||
*
|
||||
* The "Payload data" is defined as "Extension data" concatenated
|
||||
* with "Application data".
|
||||
*
|
||||
* Extension data: x bytes
|
||||
*
|
||||
* The "Extension data" is 0 bytes unless an extension has been
|
||||
* negotiated. Any extension MUST specify the length of the
|
||||
* "Extension data", or how that length may be calculated, and how
|
||||
* the extension use MUST be negotiated during the opening handshake.
|
||||
* If present, the "Extension data" is included in the total payload
|
||||
* length.
|
||||
*
|
||||
* Application data: y bytes
|
||||
*
|
||||
* Arbitrary "Application data", taking up the remainder of the frame
|
||||
* after any "Extension data". The length of the "Application data"
|
||||
* is equal to the payload length minus the length of the "Extension
|
||||
* data".
|
||||
*/
|
||||
static std::vector<uint8_t> EncodeFrame(
|
||||
SocketOpcode opcode,
|
||||
bool isFinal,
|
||||
std::vector<uint8_t> const& payload);
|
||||
|
||||
SocketState m_state{SocketState::Invalid};
|
||||
|
||||
std::vector<uint8_t> GenerateRandomKey() { return GenerateRandomBytes(16); };
|
||||
void VerifySocketAccept(std::string const& encodedKey, std::string const& acceptHeader);
|
||||
|
||||
/*********
|
||||
* Buffered Read Support. Read data from the underlying transport into a buffer.
|
||||
*/
|
||||
uint8_t ReadTransportByte(Azure::Core::Context const& context);
|
||||
uint16_t ReadTransportShort(Azure::Core::Context const& context);
|
||||
uint64_t ReadTransportInt64(Azure::Core::Context const& context);
|
||||
std::vector<uint8_t> ReadTransportBytes(size_t readLength, Azure::Core::Context const& context);
|
||||
bool IsTransportEof() const { return m_eof; }
|
||||
void SendPong(std::vector<uint8_t> const& pongData, Azure::Core::Context const& context);
|
||||
void SendTransportBuffer(
|
||||
std::vector<uint8_t> const& payload,
|
||||
Azure::Core::Context const& context);
|
||||
std::shared_ptr<WebSocketInternalFrame> ReceiveTransportFrame(
|
||||
Azure::Core::Context const& context);
|
||||
|
||||
/**
|
||||
* @brief Decode a frame received from the websocket server.
|
||||
*
|
||||
* @returns A pointer to the start of the decoded data.
|
||||
*/
|
||||
std::shared_ptr<WebSocketInternalFrame> DecodeFrame(Azure::Core::Context const& context);
|
||||
|
||||
Azure::Core::Url m_remoteUrl;
|
||||
_internal::WebSocketOptions m_options;
|
||||
std::map<std::string, std::string> m_headers;
|
||||
std::string m_chosenProtocol;
|
||||
std::shared_ptr<Azure::Core::Http::WebSockets::WebSocketTransport> m_transport;
|
||||
PingThread m_pingThread;
|
||||
SocketMessageType m_currentMessageType{SocketMessageType::Unknown};
|
||||
std::mutex m_stateMutex;
|
||||
std::thread::id m_stateOwner;
|
||||
|
||||
ReceiveStatistics m_receiveStatistics{};
|
||||
|
||||
std::mutex m_transportMutex;
|
||||
|
||||
std::unique_ptr<Azure::Core::IO::BodyStream> m_initialBodyStream;
|
||||
constexpr static size_t m_bufferSize = 1024;
|
||||
uint8_t m_buffer[m_bufferSize]{};
|
||||
size_t m_bufferPos = 0;
|
||||
size_t m_bufferLen = 0;
|
||||
bool m_eof = false;
|
||||
};
|
||||
}}}}} // namespace Azure::Core::Http::WebSockets::_detail
|
||||
221
sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp
Normal file
221
sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp
Normal file
@ -0,0 +1,221 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "azure/core/http/http.hpp"
|
||||
#include "azure/core/http/policies/policy.hpp"
|
||||
#include "azure/core/http/transport.hpp"
|
||||
#include "azure/core/http/websockets/win_http_websockets_transport.hpp"
|
||||
#include "azure/core/internal/diagnostics/log.hpp"
|
||||
#include "azure/core/platform.hpp"
|
||||
|
||||
#if defined(AZ_PLATFORM_POSIX)
|
||||
#include <poll.h> // for poll()
|
||||
#include <sys/socket.h> // for socket shutdown
|
||||
#elif defined(AZ_PLATFORM_WINDOWS)
|
||||
#if !defined(WIN32_LEAN_AND_MEAN)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#endif
|
||||
#if !defined(NOMINMAX)
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <winapifamily.h>
|
||||
#include <winsock2.h> // for WSAPoll();
|
||||
#endif
|
||||
#include <shared_mutex>
|
||||
|
||||
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
|
||||
|
||||
void WinHttpWebSocketTransport::OnUpgradedConnection(
|
||||
Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle)
|
||||
{
|
||||
// Convert the request handle into a WebSocket handle for us to use later.
|
||||
m_socketHandle = Azure::Core::Http::_detail::unique_HINTERNET(
|
||||
WinHttpWebSocketCompleteUpgrade(requestHandle.get(), 0),
|
||||
Azure::Core::Http::_detail::HINTERNET_deleter{});
|
||||
if (!m_socketHandle)
|
||||
{
|
||||
GetErrorAndThrow("Error Upgrading HttpRequest handle to WebSocket handle.");
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Azure::Core::Http::RawResponse> WinHttpWebSocketTransport::Send(
|
||||
Azure::Core::Http::Request& request,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
return WinHttpTransport::Send(request, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Close the WebSocket cleanly.
|
||||
*/
|
||||
void WinHttpWebSocketTransport::Close() { m_socketHandle.reset(); }
|
||||
|
||||
// Native WebSocket support methods.
|
||||
/**
|
||||
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native websockets.
|
||||
*
|
||||
* @param status Status value to be sent to the remote node. Application defined.
|
||||
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
*/
|
||||
void WinHttpWebSocketTransport::NativeCloseSocket(
|
||||
uint16_t status,
|
||||
std::string const& disconnectReason,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
context.ThrowIfCancelled();
|
||||
|
||||
auto err = WinHttpWebSocketClose(
|
||||
m_socketHandle.get(),
|
||||
status,
|
||||
disconnectReason.empty()
|
||||
? nullptr
|
||||
: reinterpret_cast<PVOID>(const_cast<char*>(disconnectReason.c_str())),
|
||||
static_cast<DWORD>(disconnectReason.size()));
|
||||
if (err != 0)
|
||||
{
|
||||
GetErrorAndThrow("WinHttpWebSocketClose() failed", err);
|
||||
}
|
||||
|
||||
context.ThrowIfCancelled();
|
||||
|
||||
// Make sure that the server responds gracefully to the close request.
|
||||
auto closeInformation = NativeGetCloseSocketInformation(context);
|
||||
|
||||
// The server should return the same status we sent.
|
||||
if (closeInformation.CloseReason != status)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Close status mismatch, got " + std::to_string(closeInformation.CloseReason)
|
||||
+ " expected " + std::to_string(status));
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Retrieve the information associated with a WebSocket close response.
|
||||
*
|
||||
* Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType
|
||||
*
|
||||
* @param context Context for the operation.
|
||||
*
|
||||
* @returns a tuple containing the status code and string.
|
||||
*/
|
||||
WinHttpWebSocketTransport::NativeWebSocketCloseInformation
|
||||
WinHttpWebSocketTransport::NativeGetCloseSocketInformation(Azure::Core::Context const& context)
|
||||
{
|
||||
context.ThrowIfCancelled();
|
||||
uint16_t closeStatus = 0;
|
||||
char closeReason[WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH]{};
|
||||
DWORD closeReasonLength;
|
||||
|
||||
auto err = WinHttpWebSocketQueryCloseStatus(
|
||||
m_socketHandle.get(),
|
||||
&closeStatus,
|
||||
closeReason,
|
||||
WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH,
|
||||
&closeReasonLength);
|
||||
if (err != 0)
|
||||
{
|
||||
GetErrorAndThrow("WinHttpGetCloseStatus() failed", err);
|
||||
}
|
||||
return NativeWebSocketCloseInformation{closeStatus, std::string(closeReason)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Send a frame of data to the remote node.
|
||||
*
|
||||
* @details Not implemented for CURL websockets because CURL does not support native
|
||||
* websockets.
|
||||
*
|
||||
* @brief frameType Frame type sent to the server, Text or Binary.
|
||||
* @brief frameData Frame data to be sent to the server.
|
||||
*/
|
||||
void WinHttpWebSocketTransport::NativeSendFrame(
|
||||
NativeWebSocketFrameType frameType,
|
||||
std::vector<uint8_t> const& frameData,
|
||||
Azure::Core::Context const& context)
|
||||
{
|
||||
context.ThrowIfCancelled();
|
||||
WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType;
|
||||
switch (frameType)
|
||||
{
|
||||
case NativeWebSocketFrameType::Text:
|
||||
bufferType = WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE;
|
||||
break;
|
||||
case NativeWebSocketFrameType::Binary:
|
||||
bufferType = WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE;
|
||||
break;
|
||||
case NativeWebSocketFrameType::BinaryFragment:
|
||||
bufferType = WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE;
|
||||
break;
|
||||
case NativeWebSocketFrameType::TextFragment:
|
||||
bufferType = WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unknown frame type: " + std::to_string(static_cast<uint32_t>(frameType)));
|
||||
break;
|
||||
}
|
||||
// Lock the socket to prevent concurrent writes. WinHTTP gets annoyed if
|
||||
// there are multiple WinHttpWebSocketSend requests outstanding.
|
||||
std::lock_guard<std::mutex> lock(m_sendMutex);
|
||||
auto err = WinHttpWebSocketSend(
|
||||
m_socketHandle.get(),
|
||||
bufferType,
|
||||
reinterpret_cast<PVOID>(const_cast<uint8_t*>(frameData.data())),
|
||||
static_cast<DWORD>(frameData.size()));
|
||||
if (err != 0)
|
||||
{
|
||||
GetErrorAndThrow("WinHttpWebSocketSend() failed", err);
|
||||
}
|
||||
}
|
||||
|
||||
WinHttpWebSocketTransport::NativeWebSocketReceiveInformation
|
||||
WinHttpWebSocketTransport::NativeReceiveFrame(Azure::Core::Context const& context)
|
||||
{
|
||||
WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType;
|
||||
NativeWebSocketFrameType frameTypeReceived;
|
||||
DWORD bufferBytesRead;
|
||||
std::vector<uint8_t> buffer(128);
|
||||
context.ThrowIfCancelled();
|
||||
std::lock_guard<std::mutex> lock(m_receiveMutex);
|
||||
|
||||
auto err = WinHttpWebSocketReceive(
|
||||
m_socketHandle.get(),
|
||||
reinterpret_cast<PVOID>(buffer.data()),
|
||||
static_cast<DWORD>(buffer.size()),
|
||||
&bufferBytesRead,
|
||||
&bufferType);
|
||||
if (err != 0 && err != ERROR_INSUFFICIENT_BUFFER)
|
||||
{
|
||||
GetErrorAndThrow("WinHttpWebSocketReceive() failed", err);
|
||||
}
|
||||
buffer.resize(bufferBytesRead);
|
||||
|
||||
switch (bufferType)
|
||||
{
|
||||
case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE:
|
||||
frameTypeReceived = NativeWebSocketFrameType::Text;
|
||||
break;
|
||||
case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE:
|
||||
frameTypeReceived = NativeWebSocketFrameType::Binary;
|
||||
break;
|
||||
case WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE:
|
||||
frameTypeReceived = NativeWebSocketFrameType::BinaryFragment;
|
||||
break;
|
||||
case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE:
|
||||
frameTypeReceived = NativeWebSocketFrameType::TextFragment;
|
||||
break;
|
||||
case WINHTTP_WEB_SOCKET_CLOSE_BUFFER_TYPE:
|
||||
frameTypeReceived = NativeWebSocketFrameType::Closed;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unknown frame type: " + std::to_string(bufferType));
|
||||
break;
|
||||
}
|
||||
return NativeWebSocketReceiveInformation{frameTypeReceived, buffer};
|
||||
}
|
||||
|
||||
}}}} // namespace Azure::Core::Http::WebSockets
|
||||
@ -1,5 +1,5 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# SPDX-License-Identifier: MIT
|
||||
param(
|
||||
[string] $LogFileLocation = "$($env:BUILD_SOURCESDIRECTORY)/WebSocketServer.log"
|
||||
)
|
||||
|
||||
155
sdk/core/azure-core/test/ut/websocket_server.py
Normal file
155
sdk/core/azure-core/test/ut/websocket_server.py
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# SPDX-License-Identifier: MIT
|
||||
from array import array
|
||||
import asyncio
|
||||
from operator import length_hint
|
||||
import threading
|
||||
from time import sleep
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import websockets
|
||||
|
||||
# create handler for each connection
|
||||
customPaths = {}
|
||||
stop = False
|
||||
|
||||
async def handleControlPath(websocket):
|
||||
while (1):
|
||||
data : str = await websocket.recv()
|
||||
parsedCommand = data.split(' ')
|
||||
if (parsedCommand[0] == "close"):
|
||||
print("Closing control channel")
|
||||
await websocket.send("ok")
|
||||
print("Terminating WebSocket server.")
|
||||
stop.set_result(0)
|
||||
break
|
||||
elif parsedCommand[0] == "newPath":
|
||||
print("Add path")
|
||||
newPath = parsedCommand[1]
|
||||
print(" Add path ", newPath)
|
||||
customPaths[newPath] = {"path": newPath, "delay": int(parsedCommand[2]) }
|
||||
await websocket.send("ok")
|
||||
else:
|
||||
print("Unknown command, echoing it.")
|
||||
await websocket.send(data)
|
||||
|
||||
async def handleCustomPath(websocket, path:dict):
|
||||
print("Handle custom path", path)
|
||||
data : str = await websocket.recv()
|
||||
print("Received ", data)
|
||||
if ("delay" in path.keys()):
|
||||
sleep(path["delay"])
|
||||
print("Responding")
|
||||
await websocket.send(data)
|
||||
await websocket.close()
|
||||
|
||||
def HexEncode(data: bytes)->str:
|
||||
rv=""
|
||||
for val in data:
|
||||
rv+= '{:02X}'.format(val)
|
||||
return rv
|
||||
|
||||
def ParseQuery(url : ParseResult) -> dict:
|
||||
rv={}
|
||||
if len(url.query)!=0:
|
||||
args = url.query.split('&')
|
||||
for arg in args:
|
||||
vals=arg.split('=')
|
||||
rv[vals[0]]=vals[1]
|
||||
return rv
|
||||
|
||||
echo_count_lock = threading.Lock()
|
||||
echo_count_recv = 0
|
||||
echo_count_send = 0
|
||||
client_count = 0
|
||||
async def handleEcho(websocket, url:ParseResult):
|
||||
global client_count
|
||||
global echo_count_recv
|
||||
global echo_count_send
|
||||
global echo_count_lock
|
||||
queryValues = ParseQuery(url)
|
||||
while websocket.open:
|
||||
try:
|
||||
data = await websocket.recv()
|
||||
with echo_count_lock:
|
||||
echo_count_recv+=1
|
||||
if 'delay' in queryValues:
|
||||
print(f"sleeping for {queryValues['delay']} seconds")
|
||||
await asyncio.sleep(float(queryValues['delay']))
|
||||
print("woken up.")
|
||||
|
||||
if 'fragment' in queryValues and queryValues['fragment']=='true':
|
||||
await websocket.send(data.split())
|
||||
else:
|
||||
await websocket.send(data)
|
||||
with echo_count_lock:
|
||||
echo_count_send+=1
|
||||
except websockets.ConnectionClosedOK:
|
||||
print("Connection closed ok.")
|
||||
with echo_count_lock:
|
||||
client_count -= 1
|
||||
print(f"Echo count: {echo_count_recv}, {echo_count_send} client_count {client_count}")
|
||||
if client_count == 0:
|
||||
echo_count_send = 0
|
||||
echo_count_recv = 0
|
||||
return
|
||||
except websockets.ConnectionClosed as ex:
|
||||
if (ex.rcvd):
|
||||
print(f"Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}")
|
||||
else:
|
||||
print(f"Connection closed. No close information.")
|
||||
with echo_count_lock:
|
||||
client_count -= 1
|
||||
print(f"Echo count: recv: {echo_count_recv}, send: {echo_count_send} client_count {client_count}")
|
||||
if client_count == 0:
|
||||
echo_count_send = 0
|
||||
echo_count_recv = 0
|
||||
return
|
||||
|
||||
async def handler(websocket, path : str):
|
||||
global client_count
|
||||
print("Socket handler: ", path)
|
||||
parsedUrl = urlparse(path)
|
||||
if (parsedUrl.path == '/openclosetest'):
|
||||
print("Open/Close Test")
|
||||
try:
|
||||
data = await websocket.recv()
|
||||
print(f"OpenCloseTest: Received {data}")
|
||||
except websockets.ConnectionClosedOK:
|
||||
print("OpenCloseTest: Connection closed ok.")
|
||||
except websockets.ConnectionClosed as ex:
|
||||
print(f"OpenCloseTest: Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}")
|
||||
return
|
||||
elif (parsedUrl.path == '/echotest'):
|
||||
with echo_count_lock:
|
||||
client_count+= 1
|
||||
await handleEcho(websocket, parsedUrl)
|
||||
elif (parsedUrl.path == '/closeduringecho'):
|
||||
data = await websocket.recv()
|
||||
await websocket.close(1001, 'closed')
|
||||
elif (parsedUrl.path =='/control'):
|
||||
await handleControlPath(websocket)
|
||||
elif (parsedUrl.path in customPaths.keys()):
|
||||
print("Found path ", path, "in control paths.")
|
||||
await handleCustomPath(websocket, customPaths[path])
|
||||
elif (parsedUrl.path == '/terminateserver'):
|
||||
print("Terminating WebSocket server.")
|
||||
stop.set_result(0)
|
||||
else:
|
||||
data = await websocket.recv()
|
||||
print("Received: ", data)
|
||||
|
||||
reply = f"Data received as: {data}!"
|
||||
await websocket.send(reply)
|
||||
|
||||
async def main():
|
||||
global stop
|
||||
print("Starting server")
|
||||
loop = asyncio.get_running_loop()
|
||||
stop = loop.create_future()
|
||||
async with websockets.serve(handler, "localhost", 8000, ping_interval=7):
|
||||
await stop # run forever.
|
||||
|
||||
if __name__=="__main__":
|
||||
asyncio.run(main())
|
||||
print("Ending server")
|
||||
875
sdk/core/azure-core/test/ut/websocket_test.cpp
Normal file
875
sdk/core/azure-core/test/ut/websocket_test.cpp
Normal file
@ -0,0 +1,875 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "../../src/http/websockets/websockets_impl.hpp"
|
||||
#include "azure/core/http/websockets/websockets.hpp"
|
||||
#include "azure/core/internal/json/json.hpp"
|
||||
#include <chrono>
|
||||
#include <gtest/gtest.h>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
|
||||
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
|
||||
#endif
|
||||
// cspell::words closeme flibbityflobbidy
|
||||
|
||||
using namespace Azure::Core;
|
||||
using namespace Azure::Core::Http::WebSockets;
|
||||
using namespace Azure::Core::Http::WebSockets::_internal;
|
||||
using namespace std::chrono_literals;
|
||||
|
||||
constexpr uint16_t UndefinedButLegalCloseReason = 4500;
|
||||
|
||||
class WebSocketTests : public testing::Test {
|
||||
private:
|
||||
protected:
|
||||
// Create
|
||||
static void SetUpTestSuite() {}
|
||||
static void TearDownTestSuite() {}
|
||||
};
|
||||
|
||||
TEST_F(WebSocketTests, CreateSimpleSocket)
|
||||
{
|
||||
{
|
||||
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000"));
|
||||
defaultSocket.AddHeader("newHeader", "headerValue");
|
||||
EXPECT_THROW(defaultSocket.GetNegotiatedProtocol(), std::runtime_error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, OpenSimpleSocket)
|
||||
{
|
||||
{
|
||||
WebSocketOptions options;
|
||||
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"), options);
|
||||
defaultSocket.AddHeader("newHeader", "headerValue");
|
||||
|
||||
defaultSocket.Open();
|
||||
|
||||
EXPECT_THROW(defaultSocket.AddHeader("newHeader", "headerValue"), std::runtime_error);
|
||||
|
||||
// Close the socket without notifying the peer.
|
||||
defaultSocket.Close();
|
||||
}
|
||||
|
||||
{
|
||||
WebSocketOptions options;
|
||||
WebSocket defaultSocket(Azure::Core::Url("http://www.microsoft.com/"), options);
|
||||
defaultSocket.AddHeader("newHeader", "headerValue");
|
||||
|
||||
// When running this test locally, the call times out, so drop in a 5 second timeout on
|
||||
// the request.
|
||||
Azure::Core::Context requestContext = Azure::Core::Context::ApplicationContext.WithDeadline(
|
||||
std::chrono::system_clock::now() + 5s);
|
||||
EXPECT_THROW(defaultSocket.Open(requestContext), std::runtime_error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, OpenAndCloseSocket)
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"));
|
||||
defaultSocket.AddHeader("newHeader", "headerValue");
|
||||
|
||||
defaultSocket.Open();
|
||||
|
||||
// Close the socket without notifying the peer.
|
||||
defaultSocket.Close(UndefinedButLegalCloseReason);
|
||||
}
|
||||
|
||||
{
|
||||
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"));
|
||||
|
||||
defaultSocket.Open();
|
||||
|
||||
// Close the socket without notifying the peer.
|
||||
defaultSocket.Close(UndefinedButLegalCloseReason, "This is a good reason.");
|
||||
|
||||
//
|
||||
// Now re-open the socket - this should work to reset everything.
|
||||
defaultSocket.Open();
|
||||
EXPECT_THROW(defaultSocket.Open(), std::runtime_error);
|
||||
defaultSocket.Close();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, SimpleEcho)
|
||||
{
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
testSocket.SendFrame("Test message", true);
|
||||
|
||||
auto response = testSocket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::TextFrameReceived, response->FrameType);
|
||||
EXPECT_THROW(response->AsBinaryFrame(), std::logic_error);
|
||||
auto textResult = response->AsTextFrame();
|
||||
EXPECT_EQ("Test message", textResult->Text);
|
||||
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?delay=5"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
std::vector<uint8_t> binaryData{1, 2, 3, 4, 5, 6};
|
||||
|
||||
testSocket.SendFrame(binaryData, true);
|
||||
|
||||
auto response = testSocket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
EXPECT_THROW(response->AsPeerCloseFrame(), std::logic_error);
|
||||
EXPECT_THROW(response->AsTextFrame(), std::logic_error);
|
||||
auto textResult = response->AsBinaryFrame();
|
||||
EXPECT_EQ(binaryData, textResult->Data);
|
||||
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?fragment=true&delay=5"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
std::vector<uint8_t> binaryData{1, 2, 3, 4, 5, 6};
|
||||
|
||||
testSocket.SendFrame(binaryData, true);
|
||||
|
||||
std::vector<uint8_t> responseData;
|
||||
std::shared_ptr<WebSocketFrame> response;
|
||||
do
|
||||
{
|
||||
response = testSocket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
auto binaryResult = response->AsBinaryFrame();
|
||||
responseData.insert(responseData.end(), binaryResult->Data.begin(), binaryResult->Data.end());
|
||||
} while (!response->IsFinalFrame);
|
||||
|
||||
auto textResult = response->AsBinaryFrame();
|
||||
EXPECT_EQ(binaryData, responseData);
|
||||
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t N> void EchoRandomData(WebSocket& socket)
|
||||
{
|
||||
std::vector<uint8_t> sendData = Azure::Core::Http::WebSockets::_detail::GenerateRandomBytes(N);
|
||||
|
||||
socket.SendFrame(sendData, true);
|
||||
|
||||
std::vector<uint8_t> receiveData;
|
||||
|
||||
std::shared_ptr<WebSocketFrame> response;
|
||||
do
|
||||
{
|
||||
response = socket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
auto binaryResult = response->AsBinaryFrame();
|
||||
receiveData.insert(receiveData.end(), binaryResult->Data.begin(), binaryResult->Data.end());
|
||||
} while (!response->IsFinalFrame);
|
||||
|
||||
// Make sure we get back the data we sent in the echo request.
|
||||
EXPECT_EQ(sendData.size(), receiveData.size());
|
||||
EXPECT_EQ(sendData, receiveData);
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, VariableSizeEcho)
|
||||
{
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
|
||||
|
||||
testSocket.Open();
|
||||
{
|
||||
EchoRandomData<100>(testSocket);
|
||||
EchoRandomData<124>(testSocket);
|
||||
EchoRandomData<125>(testSocket);
|
||||
// The websocket protocol treats lengths of 125, 126 and > 127 specially.
|
||||
EchoRandomData<126>(testSocket);
|
||||
EchoRandomData<127>(testSocket);
|
||||
EchoRandomData<128>(testSocket);
|
||||
EchoRandomData<1020>(testSocket); // 1K-4
|
||||
EchoRandomData<1021>(testSocket); // 1K-3
|
||||
EchoRandomData<1022>(testSocket); // 1K-2
|
||||
EchoRandomData<1023>(testSocket); // 1K-1
|
||||
EchoRandomData<1024>(testSocket); // 1K
|
||||
EchoRandomData<2048>(testSocket); // 2K
|
||||
EchoRandomData<4096>(testSocket); // 4K
|
||||
EchoRandomData<8192>(testSocket); // 8K
|
||||
// The websocket protocol treats lengths of >65536 specially.
|
||||
EchoRandomData<65535>(testSocket); // 64K-1
|
||||
EchoRandomData<65536>(testSocket); // 64K
|
||||
EchoRandomData<65537>(testSocket); // 64K+1
|
||||
EchoRandomData<131072>(testSocket); // 128K
|
||||
}
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
}
|
||||
|
||||
// Generator for random bytes. Used in WebSocketImplementation and tests.
|
||||
std::vector<uint8_t> GenerateRandomBytes(size_t index, size_t vectorSize)
|
||||
{
|
||||
std::random_device randomEngine;
|
||||
|
||||
std::vector<uint8_t> rv(vectorSize + 4);
|
||||
rv[0] = index & 0xff;
|
||||
rv[1] = (index >> 8) & 0xff;
|
||||
rv[2] = (index >> 16) & 0xff;
|
||||
rv[3] = (index >> 24) & 0xff;
|
||||
std::generate(std::begin(rv) + 4, std::end(rv), [&randomEngine]() mutable {
|
||||
return static_cast<uint8_t>(randomEngine() % UINT8_MAX);
|
||||
});
|
||||
return rv;
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, CloseDuringEcho)
|
||||
{
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
testSocket.SendFrame("Test message", true);
|
||||
|
||||
auto response = testSocket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::PeerClosedReceived, response->FrameType);
|
||||
auto PeerClosedReceived = response->AsPeerCloseFrame();
|
||||
EXPECT_EQ(1001, PeerClosedReceived->RemoteStatusCode);
|
||||
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
|
||||
// Close the websocket while a thread is waiting for a response.
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/echotest?delay=10"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
std::thread testThread([&]() {
|
||||
try
|
||||
{
|
||||
std::vector<uint8_t> sendData = GenerateRandomBytes(0, 100);
|
||||
testSocket.SendFrame(sendData);
|
||||
GTEST_LOG_(INFO) << "Receive frame.";
|
||||
auto response = testSocket.ReceiveFrame();
|
||||
GTEST_LOG_(INFO) << "Received frame.";
|
||||
if (response->FrameType == WebSocketFrameType::PeerClosedReceived)
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Peer closed the socket; Terminating thread.";
|
||||
return;
|
||||
}
|
||||
else if (response->FrameType != WebSocketFrameType::BinaryFrameReceived)
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Unexpected frame type received.";
|
||||
}
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
auto binaryResult = response->AsBinaryFrame();
|
||||
}
|
||||
catch (Azure::Core::OperationCancelledException& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what()
|
||||
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
|
||||
}
|
||||
catch (std::exception const& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
|
||||
}
|
||||
});
|
||||
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
// Close the socket gracefully.
|
||||
GTEST_LOG_(INFO) << "Closing Socket.";
|
||||
EXPECT_NO_THROW(testSocket.Close(UndefinedButLegalCloseReason, "Close Reason."));
|
||||
GTEST_LOG_(INFO) << "Closed Socket.";
|
||||
testThread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, ExpectThrow)
|
||||
{
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho"));
|
||||
|
||||
EXPECT_THROW(testSocket.SendFrame("Foo", true), std::runtime_error);
|
||||
std::vector<uint8_t> data{1, 2, 3, 4};
|
||||
EXPECT_THROW(testSocket.SendFrame(data, true), std::runtime_error);
|
||||
EXPECT_THROW(testSocket.ReceiveFrame(), std::runtime_error);
|
||||
}
|
||||
}
|
||||
|
||||
std::string ToHexString(std::vector<uint8_t> const& data)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for (auto const& byte : data)
|
||||
{
|
||||
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(byte);
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, PingReceiveTest)
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
|
||||
|
||||
testSocket.Open();
|
||||
if (!testSocket.HasBuiltInWebSocketSupport())
|
||||
{
|
||||
|
||||
GTEST_LOG_(INFO) << "Sleeping for 15 seconds to collect pings.";
|
||||
Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline(
|
||||
Azure::DateTime{std::chrono::system_clock::now() + 15s});
|
||||
EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException);
|
||||
auto statistics = testSocket.GetStatistics();
|
||||
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
|
||||
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
|
||||
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
|
||||
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
|
||||
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
|
||||
EXPECT_NE(0, statistics.PingFramesReceived);
|
||||
EXPECT_NE(0, statistics.PongFramesSent);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, PingSendTest)
|
||||
{
|
||||
// Configure the socket to ping every second.
|
||||
WebSocketOptions socketOptions;
|
||||
socketOptions.PingInterval = std::chrono::seconds(1);
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"), socketOptions);
|
||||
|
||||
testSocket.Open();
|
||||
if (!testSocket.HasBuiltInWebSocketSupport())
|
||||
{
|
||||
|
||||
GTEST_LOG_(INFO) << "Sleeping for 10 seconds to collect pings.";
|
||||
// Note that we cannot collect incoming pings or outgoing pongs unless we are receiving
|
||||
// data from the server.
|
||||
Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline(
|
||||
Azure::DateTime{std::chrono::system_clock::now() + 10s});
|
||||
EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException);
|
||||
auto statistics = testSocket.GetStatistics();
|
||||
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
|
||||
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
|
||||
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
|
||||
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
|
||||
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
|
||||
EXPECT_NE(0, statistics.PingFramesSent);
|
||||
EXPECT_NE(0, statistics.PongFramesReceived);
|
||||
EXPECT_NE(0, statistics.PingFramesReceived);
|
||||
EXPECT_NE(0, statistics.PongFramesSent);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, MultiThreadedTestOnSingleSocket)
|
||||
{
|
||||
constexpr size_t threadCount = 50;
|
||||
constexpr size_t testDataLength = 200000;
|
||||
constexpr size_t testDataSize = 100;
|
||||
constexpr auto testDuration = 10s;
|
||||
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
// seed test data for the operations.
|
||||
std::vector<std::vector<uint8_t>> testData(testDataLength);
|
||||
std::vector<std::vector<uint8_t>> receivedData(testDataLength);
|
||||
std::atomic_size_t iterationCount(0);
|
||||
|
||||
// Spin up threadCount threads and hammer the echo server for 10 seconds.
|
||||
std::vector<std::thread> threads;
|
||||
std::atomic_int32_t cancellationExceptions{0};
|
||||
std::atomic_int32_t exceptions{0};
|
||||
for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1)
|
||||
{
|
||||
threads.push_back(std::thread([&]() {
|
||||
std::chrono::time_point<std::chrono::system_clock> startTime
|
||||
= std::chrono::system_clock::now();
|
||||
// Set the context to expire *after* the test is supposed to finish.
|
||||
Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline(
|
||||
Azure::DateTime{startTime} + testDuration + 10s);
|
||||
size_t iteration = 0;
|
||||
try
|
||||
{
|
||||
do
|
||||
{
|
||||
iteration = iterationCount++;
|
||||
std::vector<uint8_t> sendData = GenerateRandomBytes(iteration, testDataSize);
|
||||
{
|
||||
if (iteration < testData.size())
|
||||
{
|
||||
if (testData[iteration].size() != 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl;
|
||||
}
|
||||
EXPECT_EQ(0, testData[iteration].size());
|
||||
testData[iteration] = sendData;
|
||||
}
|
||||
}
|
||||
|
||||
testSocket.SendFrame(sendData, true /*, context*/);
|
||||
auto response = testSocket.ReceiveFrame(context);
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
auto binaryResult = response->AsBinaryFrame();
|
||||
|
||||
// Make sure we get back the data we sent in the echo request.
|
||||
if (binaryResult->Data.size() == 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl;
|
||||
}
|
||||
EXPECT_EQ(sendData.size(), binaryResult->Data.size());
|
||||
{
|
||||
// There is no ordering expectation on the results, so we just remember the data
|
||||
// as it comes in. We'll make sure we received everything later on.
|
||||
if (iteration < receivedData.size())
|
||||
{
|
||||
if (receivedData[iteration].size() != 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
EXPECT_EQ(0, receivedData[iteration].size());
|
||||
receivedData[iteration] = binaryResult->Data;
|
||||
}
|
||||
}
|
||||
} while (std::chrono::system_clock::now() - startTime < testDuration);
|
||||
}
|
||||
catch (Azure::Core::OperationCancelledException& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration
|
||||
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
|
||||
cancellationExceptions++;
|
||||
}
|
||||
catch (std::exception const& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
|
||||
exceptions++;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait for all the threads to exit.
|
||||
for (auto& thread : threads)
|
||||
{
|
||||
thread.join();
|
||||
}
|
||||
|
||||
// We no longer need to worry about synchronization since all the worker threads are done.
|
||||
GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl;
|
||||
GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex
|
||||
<< testData.size() << ")" << std::endl;
|
||||
EXPECT_GE(testDataLength, iterationCount.load());
|
||||
|
||||
auto statistics = testSocket.GetStatistics();
|
||||
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
|
||||
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
|
||||
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
|
||||
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
|
||||
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
|
||||
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
|
||||
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
|
||||
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
|
||||
EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesSent);
|
||||
EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesReceived);
|
||||
|
||||
// Resize the test data to the number of actual iterations.
|
||||
testData.resize(iterationCount.load());
|
||||
receivedData.resize(iterationCount.load());
|
||||
|
||||
// If we've processed every iteration, let's make sure that we received everything we sent.
|
||||
// If we dropped some results, then we can't check to ensure that we have received everything
|
||||
// because we can't account for everything sent.
|
||||
std::multiset<std::string> testDataStrings;
|
||||
std::multiset<std::string> receivedDataStrings;
|
||||
for (auto const& data : testData)
|
||||
{
|
||||
testDataStrings.emplace(ToHexString(data));
|
||||
}
|
||||
for (auto const& data : receivedData)
|
||||
{
|
||||
receivedDataStrings.emplace(ToHexString(data));
|
||||
}
|
||||
|
||||
EXPECT_EQ(testDataStrings, receivedDataStrings);
|
||||
for (auto const& data : testDataStrings)
|
||||
{
|
||||
if (receivedDataStrings.count(data) != testDataStrings.count(data))
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data)
|
||||
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
|
||||
<< " Missing Data: " << data << std::endl;
|
||||
}
|
||||
EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data));
|
||||
}
|
||||
for (auto const& data : receivedDataStrings)
|
||||
{
|
||||
if (testDataStrings.count(data) != receivedDataStrings.count(data))
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data)
|
||||
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
|
||||
<< " Missing Data: " << data << std::endl;
|
||||
}
|
||||
|
||||
EXPECT_NE(testDataStrings.end(), testDataStrings.find(data));
|
||||
}
|
||||
|
||||
// We shouldn't have seen any exceptions during the run.
|
||||
EXPECT_EQ(0, exceptions.load());
|
||||
EXPECT_EQ(0, cancellationExceptions.load());
|
||||
}
|
||||
|
||||
TEST_F(WebSocketTests, MultiThreadedTestOnMultipleSockets)
|
||||
{
|
||||
constexpr size_t threadCount = 50;
|
||||
constexpr size_t testDataLength = 200000;
|
||||
constexpr size_t testDataSize = 100;
|
||||
constexpr auto testDuration = 10s;
|
||||
|
||||
// seed test data for the operations.
|
||||
std::vector<std::vector<uint8_t>> testData(testDataLength);
|
||||
std::vector<std::vector<uint8_t>> receivedData(testDataLength);
|
||||
std::atomic_size_t iterationCount(0);
|
||||
|
||||
// Spin up threadCount threads and hammer the echo server for 10 seconds.
|
||||
std::vector<std::thread> threads;
|
||||
std::atomic_int32_t cancellationExceptions{0};
|
||||
std::atomic_int32_t exceptions{0};
|
||||
for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1)
|
||||
{
|
||||
threads.push_back(std::thread([&]() {
|
||||
std::chrono::time_point<std::chrono::system_clock> startTime
|
||||
= std::chrono::system_clock::now();
|
||||
// Set the context to expire *after* the test is supposed to finish.
|
||||
Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline(
|
||||
Azure::DateTime{startTime} + testDuration + 10s);
|
||||
size_t iteration = 0;
|
||||
try
|
||||
{
|
||||
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
|
||||
|
||||
testSocket.Open();
|
||||
|
||||
do
|
||||
{
|
||||
iteration = iterationCount++;
|
||||
std::vector<uint8_t> sendData = GenerateRandomBytes(iteration, testDataSize);
|
||||
{
|
||||
if (iteration < testData.size())
|
||||
{
|
||||
if (testData[iteration].size() != 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl;
|
||||
}
|
||||
EXPECT_EQ(0, testData[iteration].size());
|
||||
testData[iteration] = sendData;
|
||||
}
|
||||
}
|
||||
|
||||
testSocket.SendFrame(sendData, true /*, context*/);
|
||||
auto response = testSocket.ReceiveFrame(context);
|
||||
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
|
||||
auto binaryResult = response->AsBinaryFrame();
|
||||
|
||||
// Make sure we get back the data we sent in the echo request.
|
||||
if (binaryResult->Data.size() == 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl;
|
||||
}
|
||||
EXPECT_EQ(sendData.size(), binaryResult->Data.size());
|
||||
{
|
||||
// There is no ordering expectation on the results, so we just remember the data
|
||||
// as it comes in. We'll make sure we received everything later on.
|
||||
if (iteration < receivedData.size())
|
||||
{
|
||||
if (receivedData[iteration].size() != 0)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
EXPECT_EQ(0, receivedData[iteration].size());
|
||||
receivedData[iteration] = binaryResult->Data;
|
||||
}
|
||||
}
|
||||
} while (std::chrono::system_clock::now() - startTime < testDuration);
|
||||
// Close the socket gracefully.
|
||||
testSocket.Close();
|
||||
}
|
||||
catch (Azure::Core::OperationCancelledException& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration
|
||||
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
|
||||
cancellationExceptions++;
|
||||
}
|
||||
catch (std::exception const& ex)
|
||||
{
|
||||
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
|
||||
exceptions++;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait for all the threads to exit.
|
||||
for (auto& thread : threads)
|
||||
{
|
||||
thread.join();
|
||||
}
|
||||
|
||||
// We no longer need to worry about synchronization since all the worker threads are done.
|
||||
GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl;
|
||||
GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex
|
||||
<< testData.size() << ")" << std::endl;
|
||||
EXPECT_GE(testDataLength, iterationCount.load());
|
||||
|
||||
// Resize the test data to the number of actual iterations.
|
||||
testData.resize(iterationCount.load());
|
||||
receivedData.resize(iterationCount.load());
|
||||
|
||||
// If we've processed every iteration, let's make sure that we received everything we sent.
|
||||
// If we dropped some results, then we can't check to ensure that we have received everything
|
||||
// because we can't account for everything sent.
|
||||
std::multiset<std::string> testDataStrings;
|
||||
std::multiset<std::string> receivedDataStrings;
|
||||
for (auto const& data : testData)
|
||||
{
|
||||
testDataStrings.emplace(ToHexString(data));
|
||||
}
|
||||
for (auto const& data : receivedData)
|
||||
{
|
||||
receivedDataStrings.emplace(ToHexString(data));
|
||||
}
|
||||
|
||||
EXPECT_EQ(testDataStrings, receivedDataStrings);
|
||||
for (auto const& data : testDataStrings)
|
||||
{
|
||||
if (receivedDataStrings.count(data) != testDataStrings.count(data))
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data)
|
||||
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
|
||||
<< " Missing Data: " << data << std::endl;
|
||||
}
|
||||
EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data));
|
||||
}
|
||||
for (auto const& data : receivedDataStrings)
|
||||
{
|
||||
if (testDataStrings.count(data) != receivedDataStrings.count(data))
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data)
|
||||
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
|
||||
<< " Missing Data: " << data << std::endl;
|
||||
}
|
||||
|
||||
EXPECT_NE(testDataStrings.end(), testDataStrings.find(data));
|
||||
}
|
||||
|
||||
// We shouldn't have seen any exceptions during the run.
|
||||
EXPECT_EQ(0, exceptions.load());
|
||||
EXPECT_EQ(0, cancellationExceptions.load());
|
||||
}
|
||||
|
||||
// Does not work because curl rejects the wss: scheme.
|
||||
class LibWebSocketIncrementProtocol {
|
||||
WebSocketOptions m_options{{"dumb-increment-protocol"}};
|
||||
WebSocket m_socket;
|
||||
|
||||
public:
|
||||
LibWebSocketIncrementProtocol() : m_socket{Azure::Core::Url("wss://libwebsockets.org"), m_options}
|
||||
{
|
||||
}
|
||||
|
||||
void Open() { m_socket.Open(); }
|
||||
int GetNextNumber()
|
||||
{
|
||||
// Time out in 5 seconds if no activity.
|
||||
Azure::Core::Context contextWithTimeout
|
||||
= Azure::Core::Context().WithDeadline(std::chrono::system_clock::now() + 10s);
|
||||
auto work = m_socket.ReceiveFrame(contextWithTimeout);
|
||||
if (work->FrameType == WebSocketFrameType::TextFrameReceived)
|
||||
{
|
||||
auto frame = work->AsTextFrame();
|
||||
return std::atoi(frame->Text.c_str());
|
||||
}
|
||||
if (work->FrameType == WebSocketFrameType::BinaryFrameReceived)
|
||||
{
|
||||
auto frame = work->AsBinaryFrame();
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
else if (work->FrameType == WebSocketFrameType::PeerClosedReceived)
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Remote server closed connection." << std::endl;
|
||||
throw std::runtime_error("Remote server closed connection.");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unknown result type");
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() { m_socket.SendFrame("reset\n", true); }
|
||||
void RequestClose() { m_socket.SendFrame("closeme\n", true); }
|
||||
void Close() { m_socket.Close(); }
|
||||
void Close(uint16_t closeCode, std::string const& reasonText = {})
|
||||
{
|
||||
m_socket.Close(closeCode, reasonText);
|
||||
}
|
||||
void ConsumeUntilClosed()
|
||||
{
|
||||
while (m_socket.IsOpen())
|
||||
{
|
||||
auto work = m_socket.ReceiveFrame();
|
||||
if (work->FrameType == WebSocketFrameType::PeerClosedReceived)
|
||||
{
|
||||
auto peerClose = work->AsPeerCloseFrame();
|
||||
GTEST_LOG_(INFO) << "Peer closed. Remote Code: " << std::dec << peerClose->RemoteStatusCode
|
||||
<< " (0x" << std::hex << peerClose->RemoteStatusCode << ")" << std::endl;
|
||||
if (!peerClose->RemoteCloseReason.empty())
|
||||
{
|
||||
GTEST_LOG_(INFO) << " Peer Closed Data: " << peerClose->RemoteCloseReason;
|
||||
}
|
||||
GTEST_LOG_(INFO) << std::endl;
|
||||
return;
|
||||
}
|
||||
else if (work->FrameType == WebSocketFrameType::TextFrameReceived)
|
||||
{
|
||||
auto frame = work->AsTextFrame();
|
||||
GTEST_LOG_(INFO) << "Ignoring " << frame->Text << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class LibWebSocketStatus {
|
||||
|
||||
public:
|
||||
std::string GetLWSStatus()
|
||||
{
|
||||
WebSocketOptions options;
|
||||
|
||||
options.ServiceName = "websockettest";
|
||||
// Send 3 protocols to LWS.
|
||||
options.Protocols.push_back("brownCow");
|
||||
options.Protocols.push_back("lws-status");
|
||||
options.Protocols.push_back("flibbityflobbidy");
|
||||
WebSocket serverSocket(Azure::Core::Url("wss://libwebsockets.org"), options);
|
||||
serverSocket.Open();
|
||||
|
||||
// The server should have chosen the lws-status protocol since it doesn't understand the other
|
||||
// protocols.
|
||||
EXPECT_EQ("lws-status", serverSocket.GetNegotiatedProtocol());
|
||||
std::string returnValue;
|
||||
std::shared_ptr<WebSocketFrame> lwsStatus;
|
||||
do
|
||||
{
|
||||
|
||||
lwsStatus = serverSocket.ReceiveFrame();
|
||||
EXPECT_EQ(WebSocketFrameType::TextFrameReceived, lwsStatus->FrameType);
|
||||
if (lwsStatus->FrameType == WebSocketFrameType::TextFrameReceived)
|
||||
{
|
||||
auto textFrame = lwsStatus->AsTextFrame();
|
||||
returnValue.insert(returnValue.end(), textFrame->Text.begin(), textFrame->Text.end());
|
||||
}
|
||||
} while (!lwsStatus->IsFinalFrame);
|
||||
serverSocket.Close();
|
||||
return returnValue;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(WebSocketTests, LibWebSocketOrgLwsStatus)
|
||||
{
|
||||
{
|
||||
LibWebSocketStatus lwsStatus;
|
||||
auto serverStatus = lwsStatus.GetLWSStatus();
|
||||
GTEST_LOG_(INFO) << "Server status: " << serverStatus << std::endl;
|
||||
|
||||
Azure::Core::Json::_internal::json status;
|
||||
EXPECT_NO_THROW(status = Azure::Core::Json::_internal::json::parse(serverStatus));
|
||||
EXPECT_TRUE(status["conns"].is_array());
|
||||
auto& connections = status["conns"].get_ref<std::vector<Azure::Core::Json::_internal::json>&>();
|
||||
bool foundOurConnection = false;
|
||||
|
||||
// Scan through the list of connections to find a connection from the websockettest.
|
||||
for (auto& connection : connections)
|
||||
{
|
||||
EXPECT_TRUE(connection["ua"].is_string());
|
||||
auto userAgent = connection["ua"].get<std::string>();
|
||||
if (userAgent.find("websockettest") != std::string::npos)
|
||||
{
|
||||
foundOurConnection = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(foundOurConnection);
|
||||
}
|
||||
}
|
||||
TEST_F(WebSocketTests, LibWebSocketOrgIncrement)
|
||||
{
|
||||
{
|
||||
LibWebSocketIncrementProtocol incrementProtocol;
|
||||
incrementProtocol.Open();
|
||||
|
||||
// Note that we cannot practically validate the numbers received from the service because
|
||||
// they may be in flight at the time the "Reset" call is made.
|
||||
for (auto i = 0; i < 100; i += 1)
|
||||
{
|
||||
if (i % 5 == 0)
|
||||
{
|
||||
GTEST_LOG_(INFO) << "Reset" << std::endl;
|
||||
incrementProtocol.Reset();
|
||||
}
|
||||
int number = incrementProtocol.GetNextNumber();
|
||||
GTEST_LOG_(INFO) << "Got next number " << number << std::endl;
|
||||
}
|
||||
incrementProtocol.RequestClose();
|
||||
incrementProtocol.ConsumeUntilClosed();
|
||||
}
|
||||
}
|
||||
#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
|
||||
TEST_F(WebSocketTests, CurlTransportCoverage)
|
||||
{
|
||||
{
|
||||
|
||||
Azure::Core::Http::WebSockets::CurlWebSocketTransportOptions transportOptions;
|
||||
transportOptions.HttpKeepAlive = false;
|
||||
auto transport
|
||||
= std::make_shared<Azure::Core::Http::WebSockets::CurlWebSocketTransport>(transportOptions);
|
||||
|
||||
EXPECT_THROW(transport->NativeCloseSocket(1001, {}, {}), std::runtime_error);
|
||||
EXPECT_THROW(transport->NativeGetCloseSocketInformation({}), std::runtime_error);
|
||||
EXPECT_THROW(
|
||||
transport->NativeSendFrame(WebSocketTransport::NativeWebSocketFrameType::Binary, {}, {}),
|
||||
std::runtime_error);
|
||||
EXPECT_THROW(transport->NativeReceiveFrame({}), std::runtime_error);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
Loading…
Reference in New Issue
Block a user