feat: Add WebSocket transport implementation

- Extract core WebSocket files from original Larry Osterman implementation
- Add WebSocket headers and implementation to build system
- Update HttpPipeline API usage to current version
- Include CURL and WinHTTP WebSocket adapters
- Add WebSocket test infrastructure

Successfully compiles with current main (azure-core target)

Files added:
- sdk/core/azure-core/inc/azure/core/http/websockets/*
- sdk/core/azure-core/src/http/websockets/*
- sdk/core/azure-core/src/http/curl/curl_websockets.cpp
- sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp
- sdk/core/azure-core/test/ut/websocket_test.cpp

Cherry-picked from: 5cff286c0 (Initial WebSockets implementation)
This commit is contained in:
Ryan Hurey 2025-11-21 22:40:35 +00:00
parent 5eed2ccafd
commit a090ae75da
13 changed files with 3601 additions and 2 deletions

View File

@ -43,18 +43,22 @@ if(BUILD_TRANSPORT_CURL)
src/http/curl/curl_connection_pool_private.hpp
src/http/curl/curl_connection_private.hpp
src/http/curl/curl_session_private.hpp
src/http/curl/curl_websockets.cpp
)
SET(CURL_TRANSPORT_ADAPTER_INC
inc/azure/core/http/curl_transport.hpp
inc/azure/core/http/websockets/curl_websockets_transport.hpp
)
endif()
if(BUILD_TRANSPORT_WINHTTP)
SET(WIN_TRANSPORT_ADAPTER_SRC
src/http/winhttp/win_http_transport.cpp
src/http/winhttp/win_http_request.hpp
src/http/winhttp/win_http_websockets.cpp
)
SET(WIN_TRANSPORT_ADAPTER_INC
inc/azure/core/http/win_http_transport.hpp
inc/azure/core/http/websockets/win_http_websockets_transport.hpp
)
endif()
@ -80,6 +84,8 @@ set(
inc/azure/core/http/policies/policy.hpp
inc/azure/core/http/raw_response.hpp
inc/azure/core/http/transport.hpp
inc/azure/core/http/websockets/websockets.hpp
inc/azure/core/http/websockets/websockets_transport.hpp
inc/azure/core/internal/client_options.hpp
inc/azure/core/internal/contract.hpp
inc/azure/core/internal/credentials/authorization_challenge_parser.hpp
@ -142,6 +148,8 @@ set(
src/http/transport_policy.cpp
src/http/url.cpp
src/http/user_agent.cpp
src/http/websockets/websockets.cpp
src/http/websockets/websockets_impl.cpp
src/io/body_stream.cpp
src/io/random_access_file_body_stream.cpp
src/logger.cpp

View File

@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via CURL.
*/
#pragma once
#include "azure/core/context.hpp"
#include "azure/core/http/curl_transport.hpp"
#include "azure/core/http/http.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/http/websockets/websockets_transport.hpp"
#include <memory>
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
struct CurlWebSocketTransportOptions : public Azure::Core::Http::CurlTransportOptions
{
};
/**
* @brief Concrete implementation of a WebSocket Transport that uses libcurl.
*/
class CurlWebSocketTransport : public CurlTransport, public WebSocketTransport {
public:
/**
* @brief Construct a new CurlWebSocketTransport object.
*
* @param options Optional parameter to override the default options.
*/
CurlWebSocketTransport(
CurlWebSocketTransportOptions const& options = CurlWebSocketTransportOptions())
: CurlTransport(options)
{
}
/**
* @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse
*
* @param request an HTTP Request to be send.
* @param context A context to control the request lifetime.
*
* @return unique ptr to an HTTP RawResponse.
*/
virtual std::unique_ptr<RawResponse> Send(Request& request, Context const& context) override;
/**
* @brief Indicates if the transport natively supports websockets or not.
*
* @details For the CURL websocket transport, the transport does NOT support native websockets -
* it is the responsibility of the client of the WebSocketTransport to format WebSocket protocol
* elements.
*/
virtual bool HasBuiltInWebSocketSupport() override { return false; }
/**
* @brief Closes the WebSocket handle.
*
*/
virtual void Close() override;
// Native WebSocket support methods.
/**
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
* The first param is the close reason, the second is descriptive text.
*/
virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&)
override
{
throw std::runtime_error("Not implemented.");
}
/**
* @brief Retrieve the status of the close socket operation.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
*/
NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
const Azure::Core::Context&) override
{
throw std::runtime_error("Not implemented");
}
/**
* @brief Send a frame of data to the remote node.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
*/
virtual void NativeSendFrame(
NativeWebSocketFrameType,
std::vector<uint8_t> const&,
Azure::Core::Context const&) override
{
throw std::runtime_error("Not implemented.");
}
/**
* @brief Receive a frame of data from the remote node.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
*/
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
Azure::Core::Context const&) override
{
throw std::runtime_error("Not implemented");
}
// Non-Native WebSocket support.
/**
* @brief This function is used when working with streams to pull more data from the wire.
* Function will try to keep pulling data from socket until the buffer is all written or until
* there is no more data to get from the socket.
*
* @param buffer Buffer to fill with data.
* @param bufferSize Size of buffer.
* @param context Context to control the request lifetime.
*
* @returns Buffer data received.
*
*/
virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context)
override;
/**
* @brief This method will use libcurl socket to write all the bytes from buffer.
*
* @param buffer Buffer to send.
* @param bufferSize Number of bytes to write.
* @param context Context for the operation.
*/
virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context)
override;
/**
* @brief returns true if this transport supports WebSockets, false otherwise.
*/
bool HasWebSocketSupport() const override { return true; }
private:
// std::unique_ptr cannot be constructed on an incomplete type (CurlNetworkConnection), but
// std::shared_ptr can be.
std::shared_ptr<Azure::Core::Http::CurlNetworkConnection> m_upgradedConnection;
void OnUpgradedConnection(
std::unique_ptr<Azure::Core::Http::CurlNetworkConnection>&& upgradedConnection) override;
};
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -0,0 +1,404 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Azure Core APIs implementing the WebSocket protocol [RFC 6455]
* (https://www.rfc-editor.org/rfc/rfc6455.html).
*/
#pragma once
#include "azure/core/context.hpp"
#include "azure/core/http/http.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/internal/client_options.hpp"
#include <string>
#include <vector>
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
namespace _detail {
class WebSocketImplementation;
}
namespace _internal {
enum class WebSocketFrameType : int
{
Unknown,
TextFrameReceived,
BinaryFrameReceived,
PeerClosedReceived,
};
enum class WebSocketErrorCode : uint16_t
{
OK = 1000,
EndpointDisappearing = 1001,
ProtocolError = 1002,
UnknownDataType = 1003,
Reserved1 = 1004,
NoStatusCodePresent = 1005,
ConnectionClosedWithoutCloseFrame = 1006,
InvalidMessageData = 1007,
PolicyViolation = 1008,
MessageTooLarge = 1009,
ExtensionNotFound = 1010,
UnexpectedError = 1011,
TlsHandshakeFailure = 1015,
};
class WebSocketTextFrame;
class WebSocketBinaryFrame;
class WebSocketPeerCloseFrame;
namespace _detail {
class WebSocketImplementation;
}
/** @brief Statistics about data sent and received by the WebSocket.
*
* @remarks This class is primarily intended for test collateral and debugging to allow
* a caller to determine information about the status of a WebSocket.
*
* Note: Some of these statistics are not available if the underlying transport supports native
* websockets.
*/
struct WebSocketStatistics
{
/** @brief The number of WebSocket frames sent on this WebSocket. */
uint32_t FramesSent;
/** @brief The number of bytes of data sent to the peer on this WebSocket. */
uint32_t BytesSent;
/** @brief The number of WebSocket frames received from the peer. */
uint32_t FramesReceived;
/** @brief The number of bytes received from the peer. */
uint32_t BytesReceived;
/** @brief The number of "Ping" frames received from the peer. */
uint32_t PingFramesReceived;
/** @brief The number of "Ping" frames sent to the peer. */
uint32_t PingFramesSent;
/** @brief The number of "Pong" frames received from the peer. */
uint32_t PongFramesReceived;
/** @brief The number of "Pong" frames sent to the peer. */
uint32_t PongFramesSent;
/** @brief The number of "Text" frames received from the peer. */
uint32_t TextFramesReceived;
/** @brief The number of "Text" frames sent to the peer. */
uint32_t TextFramesSent;
/** @brief The number of "Binary" frames received from the peer. */
uint32_t BinaryFramesReceived;
/** @brief The number of "Binary" frames sent to the peer. */
uint32_t BinaryFramesSent;
/** @brief The number of "Continuation" frames sent to the peer. */
uint32_t ContinuationFramesSent;
/** @brief The number of "Continuation" frames received from the peer. */
uint32_t ContinuationFramesReceived;
/** @brief The number of "Close" frames received from the peer. */
uint32_t CloseFramesReceived;
/** @brief The number of frames received which were not processed. */
uint32_t FramesDropped;
/** @brief The number of frames received which were not returned because they were received
* after the Close() method was called. */
uint32_t FramesDroppedByClose;
/** @brief The number of frames dropped because they were over the maximum payload size. */
uint32_t FramesDroppedByPayloadSizeLimit;
/** @brief The number of frames dropped because they were out of compliance with the protocol.
*/
uint32_t FramesDroppedByProtocolError;
/** @brief The number of reads performed on the transport.*/
uint32_t TransportReads;
/** @brief The number of bytes read from the transport. */
uint32_t TransportReadBytes;
};
/** @brief A frame of data received from a WebSocket.
*/
class WebSocketFrame {
public:
/** @brief The type of frame received: Text, Binary or Close. */
WebSocketFrameType FrameType{};
/** @brief True if the frame received is a "final" frame */
bool IsFinalFrame{false};
/** @brief Returns the contents of the frame as a Text frame.
* @returns A WebSocketTextFrame containing the contents of the frame.
*/
std::shared_ptr<WebSocketTextFrame> AsTextFrame();
/** @brief Returns the contents of the frame as a Binary frame.
* @returns A WebSocketBinaryFrame containing the contents of the frame.
*/
std::shared_ptr<WebSocketBinaryFrame> AsBinaryFrame();
/** @brief Returns the contents of the frame as a Peer Close frame.
* @returns A WebSocketPeerCloseFrame containing the contents of the frame.
*/
std::shared_ptr<WebSocketPeerCloseFrame> AsPeerCloseFrame();
/** @brief Construct a new instance of a WebSocketFrame.*/
WebSocketFrame() = default;
/** @brief Construct a new instance of a WebSocketFrame with a specific frame type.
* @param frameType The type of frame received.
*/
WebSocketFrame(WebSocketFrameType frameType) : FrameType{frameType} {}
/** @brief Construct a new instance of a WebSocketFrame with a specific frame type and final
* flag.
* @param frameType The type of frame received.
* @param isFinalFrame true if the frame is the final frame.
*/
WebSocketFrame(WebSocketFrameType frameType, bool isFinalFrame)
: FrameType{frameType}, IsFinalFrame{isFinalFrame}
{
}
};
/** @brief Contains the contents of a WebSocket Text frame.*/
class WebSocketTextFrame : public WebSocketFrame,
public std::enable_shared_from_this<WebSocketTextFrame> {
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
private:
public:
/** @brief Constructs a new WebSocketTextFrame */
WebSocketTextFrame() : WebSocketFrame(WebSocketFrameType::TextFrameReceived){};
/** @brief Text of the frame received from the remote peer. */
std::string Text;
private:
/** @brief Constructs a new WebSocketTextFrame
* @param isFinalFrame True if this is the final frame in a multi-frame message.
* @param body UTF-8 encoded text of the frame data.
* @param size Length in bytes of the frame body.
*/
WebSocketTextFrame(bool isFinalFrame, uint8_t const* body, size_t size)
: WebSocketFrame{WebSocketFrameType::TextFrameReceived, isFinalFrame},
Text(body, body + size)
{
}
};
/** @brief Contains the contents of a WebSocket Binary frame.*/
class WebSocketBinaryFrame : public WebSocketFrame,
public std::enable_shared_from_this<WebSocketBinaryFrame> {
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
private:
public:
/** @brief Constructs a new WebSocketBinaryFrame */
WebSocketBinaryFrame() : WebSocketFrame(WebSocketFrameType::BinaryFrameReceived){};
/** @brief Binary frame data received from the remote peer. */
std::vector<uint8_t> Data;
/** @brief Constructs a new WebSocketBinaryFrame
* @param isFinal True if this is the final frame in a multi-frame message.
* @param body binary of the frame data.
* @param size Length in bytes of the frame body.
*/
private:
WebSocketBinaryFrame(bool isFinal, uint8_t const* body, size_t size)
: WebSocketFrame{WebSocketFrameType::BinaryFrameReceived, isFinal},
Data(body, body + size)
{
}
};
/** @brief Contains the contents of a WebSocket Close frame.*/
class WebSocketPeerCloseFrame : public WebSocketFrame,
public std::enable_shared_from_this<WebSocketPeerCloseFrame> {
friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation;
public:
/** @brief Constructs a new WebSocketPeerCloseFrame */
WebSocketPeerCloseFrame() : WebSocketFrame(WebSocketFrameType::PeerClosedReceived){};
/** @brief Status code sent from the remote peer. Typically a member of the WebSocketErrorCode
* enumeration */
uint16_t RemoteStatusCode{};
/** @brief Optional text sent from the remote peer. */
std::string RemoteCloseReason;
private:
/** @brief Constructs a new WebSocketBinaryFrame
* @param remoteStatusCode Status code sent by the remote peer.
* @param remoteCloseReason Optional reason sent by the remote peer.
*/
WebSocketPeerCloseFrame(uint16_t remoteStatusCode, std::string const& remoteCloseReason)
: WebSocketFrame{WebSocketFrameType::PeerClosedReceived},
RemoteStatusCode(remoteStatusCode), RemoteCloseReason(remoteCloseReason)
{
}
};
struct WebSocketOptions : Azure::Core::_internal::ClientOptions
{
/**
* @brief The set of protocols which are supported by this client
*/
std::vector<std::string> Protocols = {};
/**
* @brief The protocol name of the service client. Used for the User-Agent header
* in the initial WebSocket handshake.
*/
std::string ServiceName;
/**
* @brief The version of the service client. Used for the User-Agent header in the
* initial WebSocket handshake
*/
std::string ServiceVersion;
/**
* @brief The period of time between ping operations, default is 60 seconds.
*/
std::chrono::duration<int64_t> PingInterval{std::chrono::seconds{60}};
/**
* @brief Construct an instance of a WebSocketOptions type.
*
* @param protocols Supported protocols for this websocket client.
*/
explicit WebSocketOptions(std::vector<std::string> protocols)
: Azure::Core::_internal::ClientOptions{}, Protocols(protocols)
{
}
WebSocketOptions() = default;
};
class WebSocket {
public:
/** @brief Constructs a new instance of a WebSocket with the specified WebSocket options.
*
* @param remoteUrl The URL of the remote WebSocket server.
* @param options The options to use for the WebSocket.
*/
explicit WebSocket(
Azure::Core::Url const& remoteUrl,
WebSocketOptions const& options = WebSocketOptions{});
/** @brief Destroys an instance of a WebSocket.
*/
~WebSocket();
/** @brief Opens a WebSocket connection to a remote server.
*
* @param context Context for the operation, used for cancellation and timeout.
*/
void Open(Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief Closes a WebSocket connection to the remote server gracefully.
*
* @param context Context for the operation.
*/
void Close(Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief Closes a WebSocket connection to the remote server with additional context.
*
* @param closeStatus 16 bit WebSocket error code.
* @param closeReason String describing the reason for closing the socket.
* @param context Context for the operation.
*/
void Close(
uint16_t closeStatus,
std::string const& closeReason = {},
Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief Sends a String frame to the remote server.
*
* @param textFrame UTF-8 encoded text to send.
* @param isFinalFrame if True, this is the final frame in a multi-frame message.
* @param context Context for the operation.
*/
void SendFrame(
std::string const& textFrame,
bool isFinalFrame = false,
Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief Sends a Binary frame to the remote server.
*
* @param binaryFrame Binary data to send.
* @param isFinalFrame if True, this is the final frame in a multi-frame message.
* @param context Context for the operation.
*/
void SendFrame(
std::vector<uint8_t> const& binaryFrame,
bool isFinalFrame = false,
Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief Receive a frame from the remote server.
*
* @param context Context for the operation.
*
* @returns The received WebSocket frame.
*
*/
std::shared_ptr<WebSocketFrame> ReceiveFrame(
Azure::Core::Context const& context = Azure::Core::Context{});
/** @brief AddHeader - Adds a header to the initial handshake.
*
* @note This API is ignored after the WebSocket is opened.
*
* @param headerName Name of header to add to the initial handshake request.
* @param headerValue Value of header to add.
*/
void AddHeader(std::string const& headerName, std::string const& headerValue);
/** @brief Determine if the WebSocket is open.
*
* @returns true if the WebSocket is open, false otherwise.
*/
bool IsOpen() const;
/** @brief Returns "true" if the configured websocket transport
* supports websockets in the transport, or if the websocket implementation
* is providing websocket protocol support.
*
* @returns true if the HTTP transport used for WebSocket support directly supports the
* WebSocket API.
*/
bool HasBuiltInWebSocketSupport() const;
/** @brief Returns the protocol chosen by the remote server during the initial handshake.
*
* @returns The protocol negotiated between client and server.
*/
std::string const& GetNegotiatedProtocol() const;
/** @brief Returns statistics about the WebSocket.
*
* @returns The statistics about the WebSocket.
*/
WebSocketStatistics GetStatistics() const;
private:
std::unique_ptr<Azure::Core::Http::WebSockets::_detail::WebSocketImplementation>
m_socketImplementation;
};
} // namespace _internal
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief Utilities to be used by HTTP WebSocket transport implementations.
*/
#pragma once
#include "azure/core/context.hpp"
#include "azure/core/http/http.hpp"
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
/**
* @brief Base class for all WebSocket transport implementations.
*/
class WebSocketTransport {
public:
/**
* @brief Web Socket Frame type, one of Text or Binary.
*/
enum class NativeWebSocketFrameType
{
/**
* @brief Indicates that the frame is a partial UTF-8 encoded text frame - it is NOT the
* complete frame to be sent to the remote node.
*/
TextFragment,
/**
* @brief Indicates that the frame is either the complete UTF-8 encoded text frame to be sent
* to the remote node or the final frame of a multipart message.
*/
Text,
/**
* @brief Indicates that the frame is either the complete binary frame to be sent
* to the remote node or the final frame of a multipart message.
*/
Binary,
/**
* @brief Indicates that the frame is a partial binary frame - it is NOT the
* complete frame to be sent to the remote node.
*/
BinaryFragment,
/**
* @brief Indicates that the frame is a "close" frame - the remote node
* sent a close frame.
*/
Closed,
};
/** @brief Close information returned from a WebSocket transport that has builtin support
* for WebSockets.
*/
struct NativeWebSocketCloseInformation
{
/**
* @brief Close response code.
*/
uint16_t CloseReason;
/**
* @brief Close reason.
*/
std::string CloseReasonDescription;
};
/** @brief Frame information returned from a WebSocket transport that has builtin support
* for WebSockets.
*/
struct NativeWebSocketReceiveInformation
{
/**
* @brief Type of frame received.
*/
NativeWebSocketFrameType FrameType;
/**
* @brief Data received.
*/
std::vector<uint8_t> FrameData;
};
/**
* @brief Destructs `%WebSocketTransport`.
*
*/
virtual ~WebSocketTransport() {}
/**
* @brief Indicates whether the transport natively supports WebSockets.
*
* @returns true if the transport has native websocket support, false otherwise.
*/
virtual bool HasBuiltInWebSocketSupport() = 0;
/**
* @brief Closes the WebSocket.
*
* Does not notify the remote endpoint that the socket is being closed.
*
*/
virtual void Close() = 0;
/**************/
/* Native WebSocket support functions*/
/**************/
/**
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
*
* @param status Status value to be sent to the remote node. Application defined.
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
* @param context Context for the operation.
*/
virtual void NativeCloseSocket(
uint16_t status,
std::string const& disconnectReason,
Azure::Core::Context const& context)
= 0;
/**
* @brief Retrieve the information associated with a WebSocket close response.
*
* @param context Context for the operation.
*
* @returns a tuple containing the status code and string.
*/
virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
Azure::Core::Context const& context)
= 0;
/**
* @brief Send a frame of data to the remote node.
*
* @param frameType Frame type sent to the server, Text or Binary.
* @param frameData Frame data to be sent to the server.
* @param context Context for the operation.
*/
virtual void NativeSendFrame(
NativeWebSocketFrameType frameType,
std::vector<uint8_t> const& frameData,
Azure::Core::Context const& context)
= 0;
/**
* @brief Receive a frame from the remote WebSocket server.
*
* @param context Context for the operation.
*
* @returns a tuple containing the Frame data received from the remote server and the type of
* data returned from the remote endpoint
*/
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
Azure::Core::Context const& context)
= 0;
/**************/
/* Non Native WebSocket support functions */
/**************/
/**
* @brief This function is used when working with streams to pull more data from the wire.
* Function will try to keep pulling data from socket until the buffer is all written or until
* there is no more data to get from the socket.
*
*/
virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) = 0;
/**
* @brief This method will use the raw socket to write all the bytes from buffer.
*
*/
virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) = 0;
protected:
/**
* @brief Constructs a default instance of `%WebSocketTransport`.
*
*/
WebSocketTransport() = default;
/**
* @brief Constructs `%HttpTransport` by copying another instance of `%HttpTransport`.
*
* @param other An instance to copy.
*/
WebSocketTransport(const WebSocketTransport& other) = default;
/**
* @brief Constructs a WebSocketTransport from another WebSocketTransport.
*
* @param other An instance to move in.
*/
WebSocketTransport(WebSocketTransport&& other) = default;
/**
* @brief Assigns one WebSocketTransport to another.
*
* @param other An instance to assign.
*
* @return A reference to this instance.
*/
WebSocketTransport& operator=(const WebSocketTransport& other) = default;
};
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
/**
* @file
* @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via WInHTTP.
*/
#pragma once
#include "azure/core/context.hpp"
#include "azure/core/http/http.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/http/websockets/websockets_transport.hpp"
#include "azure/core/http/win_http_transport.hpp"
#include <memory>
#include <mutex>
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
/**
* @brief Concrete implementation of a WebSocket Transport that uses WinHTTP.
*/
class WinHttpWebSocketTransport : public WebSocketTransport, public WinHttpTransport {
Azure::Core::Http::_detail::unique_HINTERNET m_socketHandle;
std::mutex m_sendMutex;
std::mutex m_receiveMutex;
// Called by the
void OnUpgradedConnection(
Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) override;
public:
/**
* @brief Construct a new WinHTTP WebSocket Transport.
*
* @param options Optional parameter to override the default options.
*/
WinHttpWebSocketTransport(WinHttpTransportOptions const& options = WinHttpTransportOptions())
: WinHttpTransport(options)
{
}
/**
* @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse
*
* @param request an HTTP Request to be send.
* @param context A context to control the request lifetime.
*
* @return unique ptr to an HTTP RawResponse.
*/
virtual std::unique_ptr<RawResponse> Send(Request& request, Context const& context) override;
/**
* @brief Indicates if the transports natively websockets or not.
*
* @details For the WinHTTP websocket transport, the WinHTTP API supports websockets.
*/
virtual bool HasBuiltInWebSocketSupport() override { return true; }
/**
* @brief Close the underlying WebSocket handle.
*
*/
virtual void Close() override;
// Native WebSocket support methods.
/**
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
* @param status Status value to be sent to the remote node. Application defined.
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
* @param context Context for the operation.
*
*/
virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&)
override;
/**
* @brief Retrieve the information associated with a WebSocket close response.
*
* Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType
*
* @param context Context for the operation.
*
* @returns a tuple containing the status code and string.
*/
virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation(
Azure::Core::Context const& context) override;
/**
* @brief Send a frame of data to the remote node.
*
* @details Not implemented for CURL websockets because CURL does not support native
* websockets.
*
* @brief frameType Frame type sent to the server, Text or Binary.
* @brief frameData Frame data to be sent to the server.
*/
virtual void NativeSendFrame(
NativeWebSocketFrameType,
std::vector<uint8_t> const&,
Azure::Core::Context const&) override;
virtual NativeWebSocketReceiveInformation NativeReceiveFrame(
Azure::Core::Context const&) override;
// Non-Native WebSocket support.
/**
* @brief This function is used when working with streams to pull more data from the wire.
* Function will try to keep pulling data from socket until the buffer is all written or
* until there is no more data to get from the socket.
*
* @details Not implemented for WinHTTP websockets because WinHTTP implements websockets
* natively.
*/
virtual size_t ReadFromSocket(uint8_t*, size_t, Context const&) override
{
throw std::runtime_error("Not implemented.");
}
/**
* @brief This method will use sockets to write all the bytes from buffer.
*
* @details Not implemented for WinHTTP websockets because WinHTTP implements websockets
* natively.
*
*/
virtual int SendBuffer(uint8_t const*, size_t, Context const&) override
{
throw std::runtime_error("Not implemented.");
}
bool HasWebSocketSupport() const override { return true; }
};
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/http/http.hpp"
#include "azure/core/http/policies/policy.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
#include "azure/core/internal/diagnostics/log.hpp"
#include "azure/core/platform.hpp"
// Private include
#include "curl_connection_private.hpp"
#if defined(AZ_PLATFORM_POSIX)
#include <poll.h> // for poll()
#include <sys/socket.h> // for socket shutdown
#elif defined(AZ_PLATFORM_WINDOWS)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#if !defined(NOMINMAX)
#define NOMINMAX
#endif
#include <winapifamily.h>
#include <winsock2.h> // for WSAPoll();
#endif
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
void CurlWebSocketTransport::Close() { m_upgradedConnection->Shutdown(); }
// Send an HTTP request to the remote server.
std::unique_ptr<RawResponse> CurlWebSocketTransport::Send(
Request& request,
Context const& context)
{
// CURL doesn't understand the ws and wss protocols, so change the URL to be http based.
std::string requestScheme(request.GetUrl().GetScheme());
if (requestScheme == "wss" || requestScheme == "ws")
{
if (requestScheme == "wss")
{
request.GetUrl().SetScheme("https");
}
else
{
request.GetUrl().SetScheme("http");
}
}
return CurlTransport::Send(request, context);
}
size_t CurlWebSocketTransport::ReadFromSocket(
uint8_t* buffer,
size_t bufferSize,
Context const& context)
{
return m_upgradedConnection->ReadFromSocket(buffer, bufferSize, context);
}
/**
* @brief This method will use libcurl socket to write all the bytes from buffer.
*
*/
int CurlWebSocketTransport::SendBuffer(
uint8_t const* buffer,
size_t bufferSize,
Context const& context)
{
return m_upgradedConnection->SendBuffer(buffer, bufferSize, context);
}
void CurlWebSocketTransport::OnUpgradedConnection(
std::unique_ptr<CurlNetworkConnection>&& upgradedConnection)
{
// Note that m_upgradedConnection is a std::shared_ptr. We define it as a std::shared_ptr
// because a std::shared_ptr can be declared on an incomplete type, while a std::unique_ptr
// cannot.
m_upgradedConnection = std::move(upgradedConnection);
}
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/http/websockets/websockets.hpp"
#include "azure/core/context.hpp"
#include "websockets_impl.hpp"
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _internal {
WebSocket::WebSocket(Azure::Core::Url const& remoteUrl, WebSocketOptions const& options)
: m_socketImplementation(
std::make_unique<Azure::Core::Http::WebSockets::_detail::WebSocketImplementation>(
remoteUrl,
options))
{
}
WebSocket::~WebSocket() {}
void WebSocket::Open(Azure::Core::Context const& context)
{
m_socketImplementation->Open(context);
}
void WebSocket::Close(Azure::Core::Context const& context)
{
m_socketImplementation->Close(
static_cast<uint16_t>(WebSocketErrorCode::EndpointDisappearing), {}, context);
}
void WebSocket::Close(
uint16_t closeStatus,
std::string const& closeReason,
Azure::Core::Context const& context)
{
m_socketImplementation->Close(closeStatus, closeReason, context);
}
void WebSocket::SendFrame(
std::string const& textFrame,
bool isFinalFrame,
Azure::Core::Context const& context)
{
m_socketImplementation->SendFrame(textFrame, isFinalFrame, context);
}
void WebSocket::SendFrame(
std::vector<uint8_t> const& binaryFrame,
bool isFinalFrame,
Azure::Core::Context const& context)
{
m_socketImplementation->SendFrame(binaryFrame, isFinalFrame, context);
}
WebSocketStatistics WebSocket::GetStatistics() const
{
return m_socketImplementation->GetStatistics();
}
bool WebSocket::HasBuiltInWebSocketSupport() const
{
return m_socketImplementation->HasBuiltInWebSocketSupport();
}
std::shared_ptr<WebSocketFrame> WebSocket::ReceiveFrame(Azure::Core::Context const& context)
{
return m_socketImplementation->ReceiveFrame(context);
}
void WebSocket::AddHeader(std::string const& headerName, std::string const& headerValue)
{
m_socketImplementation->AddHeader(headerName, headerValue);
}
std::string const& WebSocket::GetNegotiatedProtocol() const
{
return m_socketImplementation->GetNegotiatedProtocol();
}
bool WebSocket::IsOpen() const { return m_socketImplementation->IsOpen(); }
std::shared_ptr<WebSocketTextFrame> WebSocketFrame::AsTextFrame()
{
if (FrameType != WebSocketFrameType::TextFrameReceived)
{
throw std::logic_error("Cannot cast to TextFrameReceived.");
}
return static_cast<WebSocketTextFrame*>(this)->shared_from_this();
}
std::shared_ptr<WebSocketBinaryFrame> WebSocketFrame::AsBinaryFrame()
{
if (FrameType != WebSocketFrameType::BinaryFrameReceived)
{
throw std::logic_error("Cannot cast to BinaryFrameReceived.");
}
return static_cast<WebSocketBinaryFrame*>(this)->shared_from_this();
}
std::shared_ptr<WebSocketPeerCloseFrame> WebSocketFrame::AsPeerCloseFrame()
{
if (FrameType != WebSocketFrameType::PeerClosedReceived)
{
throw std::logic_error("Cannot cast to PeerClose.");
}
return static_cast<WebSocketPeerCloseFrame*>(this)->shared_from_this();
}
}}}}} // namespace Azure::Core::Http::WebSockets::_internal

View File

@ -0,0 +1,876 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "websockets_impl.hpp"
#include "azure/core/base64.hpp"
#include "azure/core/http/policies/policy.hpp"
#include "azure/core/internal/cryptography/sha_hash.hpp"
// SUPPORT_NATIVE_TRANSPORT indicates if WinHTTP should be compiled with native transport support
// or not.
// Note that this is primarily required to improve the code coverage numbers in the CI pipeline.
#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER)
#include "azure/core/http/websockets/win_http_websockets_transport.hpp"
#define SUPPORT_NATIVE_TRANSPORT 1
#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
#define SUPPORT_NATIVE_TRANSPORT 0
#endif
#include "azure/core/internal/diagnostics/log.hpp"
#include <algorithm>
#include <array>
#include <iomanip>
#include <mutex>
#include <random>
#include <shared_mutex>
#include <sstream>
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail {
using namespace Azure::Core::Http::WebSockets::_internal;
using namespace Azure::Core::Diagnostics::_internal;
using namespace Azure::Core::Diagnostics;
using namespace std::chrono_literals;
namespace {
std::string HexEncode(std::vector<uint8_t> const& data, size_t length)
{
std::stringstream ss;
for (size_t i = 0; i < std::min(data.size(), length); i++)
{
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(data[i]);
}
return ss.str();
}
} // namespace
WebSocketImplementation::WebSocketImplementation(
Azure::Core::Url const& remoteUrl,
WebSocketOptions const& options)
: m_remoteUrl(remoteUrl), m_options(options), m_pingThread(this, m_options.PingInterval)
{
}
void WebSocketImplementation::Open(Azure::Core::Context const& context)
{
if (m_state != SocketState::Invalid && m_state != SocketState::Closed)
{
throw std::runtime_error(
"Socket in unexpected state: " + std::to_string(static_cast<uint32_t>(m_state)));
}
m_state = SocketState::Opening;
#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER)
WinHttpTransportOptions transportOptions;
auto winHttpTransport
= std::make_shared<Azure::Core::Http::WebSockets::WinHttpWebSocketTransport>(
transportOptions);
m_transport = std::static_pointer_cast<WebSocketTransport>(winHttpTransport);
m_options.Transport.Transport = std::static_pointer_cast<HttpTransport>(winHttpTransport);
#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
CurlWebSocketTransportOptions transportOptions;
transportOptions.HttpKeepAlive = false;
auto curlWebSockets
= std::make_shared<Azure::Core::Http::WebSockets::CurlWebSocketTransport>(transportOptions);
m_transport = std::static_pointer_cast<WebSocketTransport>(curlWebSockets);
m_options.Transport.Transport = std::static_pointer_cast<HttpTransport>(curlWebSockets);
#endif
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> perCallPolicies{};
std::vector<std::unique_ptr<Azure::Core::Http::Policies::HttpPolicy>> perRetryPolicies{};
// If the caller has told us a service name, add the telemetry policy to the pipeline to add
// Create pipeline for WebSocket handshake
std::string serviceName = m_options.ServiceName.empty() ? "azure.core.websockets" : m_options.ServiceName;
std::string serviceVersion = m_options.ServiceVersion.empty() ? "1.0.0" : m_options.ServiceVersion;
Azure::Core::Http::_internal::HttpPipeline openPipeline(
m_options, serviceName, serviceVersion, std::move(perRetryPolicies), std::move(perCallPolicies));
Azure::Core::Http::Request openSocketRequest(
Azure::Core::Http::HttpMethod::Get, m_remoteUrl, false);
// Generate the random request key. Only used when the transport doesn't support websockets
// natively.
auto randomKey = GenerateRandomKey();
auto encodedKey = Azure::Core::Convert::Base64Encode(randomKey);
if (!m_transport->HasBuiltInWebSocketSupport())
{
// If the transport doesn't support WebSockets natively, set the standardized WebSocket
// upgrade headers.
openSocketRequest.SetHeader("Upgrade", "websocket");
openSocketRequest.SetHeader("Connection", "upgrade");
openSocketRequest.SetHeader("Sec-WebSocket-Version", "13");
openSocketRequest.SetHeader("Sec-WebSocket-Key", encodedKey);
}
if (!m_options.Protocols.empty())
{
std::string protocols;
for (auto const& protocol : m_options.Protocols)
{
protocols += protocol;
protocols += ", ";
}
protocols = protocols.substr(0, protocols.size() - 2);
openSocketRequest.SetHeader("Sec-WebSocket-Protocol", protocols);
}
for (auto const& additionalHeader : m_headers)
{
openSocketRequest.SetHeader(additionalHeader.first, additionalHeader.second);
}
std::string remoteOrigin;
remoteOrigin = m_remoteUrl.GetScheme();
remoteOrigin += "://";
remoteOrigin += m_remoteUrl.GetHost();
openSocketRequest.SetHeader("Origin", remoteOrigin);
// Send the connect request to the WebSocket server.
auto response = openPipeline.Send(openSocketRequest, context);
// Ensure that the server thinks we're switching protocols. If it doesn't,
// fail immediately.
if (response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::SwitchingProtocols)
{
throw Azure::Core::Http::TransportException("Unexpected handshake response");
}
// Prove that the server received this socket request.
auto& responseHeaders = response->GetHeaders();
if (!m_transport->HasBuiltInWebSocketSupport())
{
auto socketAccept(responseHeaders.find("Sec-WebSocket-Accept"));
if (socketAccept == responseHeaders.end())
{
throw Azure::Core::Http::TransportException("Missing Sec-WebSocket-Accept header");
}
// Verify that the WebSocket server received *this* open request.
else
{
VerifySocketAccept(encodedKey, socketAccept->second);
}
m_initialBodyStream = response->ExtractBodyStream();
m_pingThread.Start(m_transport);
}
// Remember the protocol that the client chose.
auto chosenProtocol = responseHeaders.find("Sec-WebSocket-Protocol");
if (chosenProtocol != responseHeaders.end())
{
m_chosenProtocol = chosenProtocol->second;
}
m_state = SocketState::Open;
}
bool WebSocketImplementation::HasBuiltInWebSocketSupport()
{
std::lock_guard<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Open)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
return m_transport->HasBuiltInWebSocketSupport();
}
std::string const& WebSocketImplementation::GetNegotiatedProtocol()
{
std::lock_guard<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Open)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
return m_chosenProtocol;
}
void WebSocketImplementation::AddHeader(std::string const& header, std::string const& headerValue)
{
std::lock_guard<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Closed && m_state != SocketState::Invalid)
{
throw std::runtime_error("AddHeader can only be called on closed sockets.");
}
m_headers.emplace(std::make_pair(header, headerValue));
}
void WebSocketImplementation::Close(
uint16_t closeStatus,
std::string const& closeReason,
Azure::Core::Context const& context)
{
std::unique_lock<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
// If we're closing an already closed socket, we're done.
if (m_state == SocketState::Closed)
{
return;
}
if (m_state != SocketState::Open)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
m_state = SocketState::Closing;
#if SUPPORT_NATIVE_TRANSPORT
if (m_transport->HasBuiltInWebSocketSupport())
{
m_transport->NativeCloseSocket(closeStatus, closeReason.c_str(), context);
}
else
#endif
{
// Send a going away message to the server.
std::vector<uint8_t> closePayload;
closePayload.push_back(closeStatus >> 8);
closePayload.push_back(closeStatus & 0xff);
closePayload.insert(closePayload.end(), closeReason.begin(), closeReason.end());
std::vector<uint8_t> closeFrame = EncodeFrame(SocketOpcode::Close, true, closePayload);
SendTransportBuffer(closeFrame, context);
// Unlock the state mutex before waiting for the close response to be received.
lock.unlock();
// Drain the incoming series of frames from the server.
// Note that there might be in-flight frames that were sent from the other end of the
// WebSocket that we don't care about any more (since we're closing the WebSocket). So
// drain those frames.
auto closeResponse = ReceiveTransportFrame(context);
while (closeResponse && closeResponse->Opcode != SocketOpcode::Close)
{
m_receiveStatistics.FramesDroppedByClose++;
Log::Write(
Logger::Level::Warning,
"Received unexpected frame during close. Opcode: "
+ std::to_string(static_cast<uint8_t>(closeResponse->Opcode)));
closeResponse = ReceiveTransportFrame(context);
}
// Re-acquire the state lock once we've received the close response.
lock.lock();
m_stateOwner = std::this_thread::get_id();
}
// Close the socket - after this point, the m_transport is invalid.
m_pingThread.Shutdown();
m_transport->Close();
m_state = SocketState::Closed;
}
void WebSocketImplementation::SendFrame(
std::string const& textFrame,
bool isFinalFrame,
Azure::Core::Context const& context)
{
std::lock_guard<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Open)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
std::vector<uint8_t> utf8text(textFrame.begin(), textFrame.end());
m_receiveStatistics.TextFramesSent++;
#if SUPPORT_NATIVE_TRANSPORT
if (m_transport->HasBuiltInWebSocketSupport())
{
m_transport->NativeSendFrame(
(isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Text
: WebSocketTransport::NativeWebSocketFrameType::TextFragment),
utf8text,
context);
}
else
#endif
{
std::vector<uint8_t> sendFrame = EncodeFrame(SocketOpcode::TextFrame, isFinalFrame, utf8text);
SendTransportBuffer(sendFrame, context);
}
}
void WebSocketImplementation::SendFrame(
std::vector<uint8_t> const& binaryFrame,
bool isFinalFrame,
Azure::Core::Context const& context)
{
std::lock_guard<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Open)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
m_receiveStatistics.BinaryFramesSent++;
#if SUPPORT_NATIVE_TRANSPORT
if (m_transport->HasBuiltInWebSocketSupport())
{
m_transport->NativeSendFrame(
(isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Binary
: WebSocketTransport::NativeWebSocketFrameType::BinaryFragment),
binaryFrame,
context);
}
else
#endif
{
std::vector<uint8_t> sendFrame
= EncodeFrame(SocketOpcode::BinaryFrame, isFinalFrame, binaryFrame);
SendTransportBuffer(sendFrame, context);
}
}
std::shared_ptr<WebSocketFrame> WebSocketImplementation::ReceiveFrame(
Azure::Core::Context const& context)
{
std::unique_lock<std::mutex> lock(m_stateMutex);
m_stateOwner = std::this_thread::get_id();
if (m_state != SocketState::Open && m_state != SocketState::Closing)
{
throw std::runtime_error(
"Socket is not open." + std::to_string(static_cast<uint32_t>(m_state)));
}
// Unlock the state lock to allow other threads to run. If we don't, we might end up in in a
// situation where the server won't respond to the this client because all the client threads
// are blocked on the state lock.
lock.unlock();
std::shared_ptr<WebSocketInternalFrame> frame;
// Loop until we receive an returnable incoming frame.
// If the incoming frame is returnable, we return the value from the frame.
while (true)
{
frame = ReceiveTransportFrame(context);
if (frame)
{
switch (frame->Opcode)
{
// When we receive a "ping" frame, we want to send a Pong frame back to the server.
case SocketOpcode::Ping:
Log::Write(
Logger::Level::Verbose, "Received Ping frame: " + HexEncode(frame->Payload, 16));
SendPong(frame->Payload, context);
break;
// We want to ignore all incoming "Pong" frames.
case SocketOpcode::Pong:
Log::Write(
Logger::Level::Verbose, "Received Pong frame: " + HexEncode(frame->Payload, 16));
break;
case SocketOpcode::BinaryFrame:
m_currentMessageType = SocketMessageType::Binary;
return std::shared_ptr<WebSocketFrame>(new WebSocketBinaryFrame(
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
case SocketOpcode::TextFrame:
m_currentMessageType = SocketMessageType::Text;
return std::shared_ptr<WebSocketFrame>(new WebSocketTextFrame(
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
case SocketOpcode::Close: {
if (frame->Payload.size() < 2)
{
throw std::runtime_error("Close response buffer is too short.");
}
// Encode the payload for close according to RFC 6455
// section 5.5.1. The first two bytes of the payload contain the status code.
// The remainder of the payload is a UTF-8 encoded string.
uint16_t errorCode = 0;
errorCode |= (frame->Payload[0] << 8) & 0xff00;
errorCode |= (frame->Payload[1] & 0x00ff);
// We received a close frame, mark the socket as closed. Make sure we
// reacquire the state lock before setting the state to closed.
lock.lock();
m_stateOwner = std::this_thread::get_id();
m_state = SocketState::Closed;
return std::shared_ptr<WebSocketFrame>(new WebSocketPeerCloseFrame(
errorCode, std::string(frame->Payload.begin() + 2, frame->Payload.end())));
}
// Continuation frames need to be treated somewhat specially.
// We depend on the fact that the protocol requires that a Continuation frame
// only be sent if it is part of a multi-frame message whose previous frame was a Text
// or Binary frame.
case SocketOpcode::Continuation:
if (m_currentMessageType == SocketMessageType::Text)
{
if (frame->IsFinalFrame)
{
m_currentMessageType = SocketMessageType::Unknown;
}
return std::shared_ptr<WebSocketFrame>(new WebSocketTextFrame(
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
}
else if (m_currentMessageType == SocketMessageType::Binary)
{
if (frame->IsFinalFrame)
{
m_currentMessageType = SocketMessageType::Unknown;
}
return std::shared_ptr<WebSocketFrame>(new WebSocketBinaryFrame(
frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size()));
}
else
{
m_receiveStatistics.FramesDroppedByProtocolError++;
throw std::runtime_error("Unknown message type and received continuation opcode");
}
default:
throw std::runtime_error("Unknown frame type received.");
}
}
else
{
if (m_state != SocketState::Closed && m_state != SocketState::Closing)
{
throw std::runtime_error("Transport is at EOF, no frame to receive.");
}
// The socket was closed, most likely locally, so fake a close frame response.
return std::shared_ptr<WebSocketFrame>(new WebSocketPeerCloseFrame());
}
context.ThrowIfCancelled();
}
}
std::shared_ptr<WebSocketImplementation::WebSocketInternalFrame>
WebSocketImplementation::ReceiveTransportFrame(Azure::Core::Context const& context)
{
#if SUPPORT_NATIVE_TRANSPORT
if (m_transport->HasBuiltInWebSocketSupport())
{
auto payload = m_transport->NativeReceiveFrame(context);
m_receiveStatistics.FramesReceived++;
switch (payload.FrameType)
{
case WebSocketTransport::NativeWebSocketFrameType::Binary:
m_receiveStatistics.BinaryFramesReceived++;
return std::make_shared<WebSocketInternalFrame>(
SocketOpcode::BinaryFrame, true, payload.FrameData);
case WebSocketTransport::NativeWebSocketFrameType::BinaryFragment:
m_receiveStatistics.BinaryFramesReceived++;
return std::make_shared<WebSocketInternalFrame>(
SocketOpcode::BinaryFrame, false, payload.FrameData);
case WebSocketTransport::NativeWebSocketFrameType::Text:
m_receiveStatistics.TextFramesReceived++;
return std::make_shared<WebSocketInternalFrame>(
SocketOpcode::TextFrame, true, payload.FrameData);
case WebSocketTransport::NativeWebSocketFrameType::TextFragment:
m_receiveStatistics.TextFramesReceived++;
return std::make_shared<WebSocketInternalFrame>(
SocketOpcode::TextFrame, false, payload.FrameData);
case WebSocketTransport::NativeWebSocketFrameType::Closed: {
m_receiveStatistics.CloseFramesReceived++;
auto closeResult = m_transport->NativeGetCloseSocketInformation(context);
std::vector<uint8_t> closePayload;
closePayload.push_back(closeResult.CloseReason >> 8);
closePayload.push_back(closeResult.CloseReason & 0xff);
closePayload.insert(
closePayload.end(),
closeResult.CloseReasonDescription.begin(),
closeResult.CloseReasonDescription.end());
return std::make_shared<WebSocketInternalFrame>(SocketOpcode::Close, true, closePayload);
}
default:
throw std::runtime_error("Unexpected frame type received.");
}
}
else
#endif
{
std::shared_ptr<WebSocketInternalFrame> frame = DecodeFrame(context);
if (frame)
{
// Handle statistics for the incoming frame.
m_receiveStatistics.FramesReceived++;
switch (frame->Opcode)
{
case SocketOpcode::Ping: {
m_receiveStatistics.PingFramesReceived++;
break;
}
case SocketOpcode::Pong: {
m_receiveStatistics.PongFramesReceived++;
break;
}
case SocketOpcode::TextFrame: {
m_receiveStatistics.TextFramesReceived++;
break;
}
case SocketOpcode::BinaryFrame: {
m_receiveStatistics.BinaryFramesReceived++;
break;
}
case SocketOpcode::Close: {
m_receiveStatistics.CloseFramesReceived++;
break;
}
case SocketOpcode::Continuation: {
m_receiveStatistics.ContinuationFramesReceived++;
break;
}
default: {
m_receiveStatistics.UnknownFramesReceived++;
break;
}
}
}
else
{
m_receiveStatistics.FramesDropped++;
}
return frame;
}
}
WebSocketStatistics WebSocketImplementation::GetStatistics() const
{
WebSocketStatistics returnValue{};
returnValue.FramesSent = m_receiveStatistics.FramesSent.load();
returnValue.FramesReceived = m_receiveStatistics.FramesReceived.load();
returnValue.BinaryFramesReceived = m_receiveStatistics.BinaryFramesReceived.load();
returnValue.TextFramesReceived = m_receiveStatistics.TextFramesReceived.load();
returnValue.BinaryFramesSent = m_receiveStatistics.BinaryFramesSent.load();
returnValue.TextFramesSent = m_receiveStatistics.TextFramesSent.load();
returnValue.PingFramesReceived = m_receiveStatistics.PingFramesReceived.load();
returnValue.PongFramesReceived = m_receiveStatistics.PongFramesReceived.load();
returnValue.PingFramesSent = m_receiveStatistics.PingFramesSent.load();
returnValue.PongFramesSent = m_receiveStatistics.PongFramesSent.load();
returnValue.BytesSent = m_receiveStatistics.BytesSent.load();
returnValue.BytesReceived = m_receiveStatistics.BytesReceived.load();
returnValue.FramesDropped = m_receiveStatistics.FramesDropped.load();
returnValue.FramesDroppedByClose = m_receiveStatistics.FramesDroppedByClose.load();
returnValue.FramesDroppedByPayloadSizeLimit
= m_receiveStatistics.FramesDroppedByPayloadSizeLimit.load();
returnValue.FramesDroppedByProtocolError
= m_receiveStatistics.FramesDroppedByProtocolError.load();
returnValue.TransportReadBytes = m_receiveStatistics.TransportReadBytes.load();
returnValue.TransportReads = m_receiveStatistics.TransportReads.load();
return returnValue;
}
std::vector<uint8_t> WebSocketImplementation::EncodeFrame(
SocketOpcode opcode,
bool isFinal,
std::vector<uint8_t> const& payload)
{
std::vector<uint8_t> encodedFrame;
// Add opcode+fin.
encodedFrame.push_back(static_cast<uint8_t>(opcode) | (isFinal ? 0x80 : 0));
uint8_t maskAndLength = 0;
maskAndLength |= 0x80;
// Payloads smaller than 125 bytes are encoded directly in the maskAndLength field.
uint64_t payloadSize = static_cast<uint64_t>(payload.size());
if (payloadSize <= 125)
{
maskAndLength |= static_cast<uint8_t>(payload.size());
}
else if (payloadSize <= 65535)
{
// Payloads greater than 125 whose size can fit in a 16 bit integer bytes
// are encoded as a 16 bit unsigned integer in network byte order.
maskAndLength |= 126;
}
else
{
// Payloads greater than 65536 have their length are encoded as a 64 bit unsigned integer
// in network byte order.
maskAndLength |= 127;
}
encodedFrame.push_back(maskAndLength);
// Encode a 16 bit length.
if (payloadSize > 125 && payloadSize <= 65535)
{
encodedFrame.push_back(static_cast<uint16_t>(payload.size()) >> 8);
encodedFrame.push_back(static_cast<uint16_t>(payload.size()) & 0xff);
}
// Encode a 64 bit length.
else if (payloadSize >= 65536)
{
encodedFrame.push_back((payloadSize >> 56) & 0xff);
encodedFrame.push_back((payloadSize >> 48) & 0xff);
encodedFrame.push_back((payloadSize >> 40) & 0xff);
encodedFrame.push_back((payloadSize >> 32) & 0xff);
encodedFrame.push_back((payloadSize >> 24) & 0xff);
encodedFrame.push_back((payloadSize >> 16) & 0xff);
encodedFrame.push_back((payloadSize >> 8) & 0xff);
encodedFrame.push_back(payloadSize & 0xff);
}
// Calculate the masking key. This MUST be 4 bytes of high entropy random numbers used to
// mask the input data.
{
// Start by generating the mask - 4 bytes of random data.
std::vector<uint8_t> mask = GenerateRandomBytes(4);
// Append the mask to the payload.
encodedFrame.insert(encodedFrame.end(), mask.begin(), mask.end());
// And mask the payload before transmitting it.
size_t index = 0;
for (auto ch : payload)
{
encodedFrame.push_back(ch ^ mask[index % 4]);
index += 1;
}
}
return encodedFrame;
}
std::shared_ptr<WebSocketImplementation::WebSocketInternalFrame>
WebSocketImplementation::DecodeFrame(Azure::Core::Context const& context)
{
// Ensure single threaded access to receive this frame.
std::unique_lock<std::mutex> lock(m_transportMutex);
if (IsTransportEof())
{
throw std::runtime_error("Frame buffer is too small.");
}
uint8_t payloadByte = ReadTransportByte(context);
// If the transport is at EOF, then there is no payload data, so just return null.
if (IsTransportEof())
{
return nullptr;
}
SocketOpcode opcode = static_cast<SocketOpcode>(payloadByte & 0x7f);
bool isFinal = (payloadByte & 0x80) != 0;
payloadByte = ReadTransportByte(context);
if (IsTransportEof())
{
return nullptr;
}
if (payloadByte & 0x80)
{
throw std::runtime_error("Server sent a frame with a reserved bit set.");
}
int64_t payloadLength = payloadByte & 0x7f;
if (payloadLength <= 125)
{
payloadByte += 1;
}
else if (payloadLength == 126)
{
payloadLength = ReadTransportShort(context);
}
else if (payloadLength == 127)
{
payloadLength = ReadTransportInt64(context);
}
else
{
throw std::logic_error("Unexpected payload length.");
}
if (IsTransportEof())
{
return nullptr;
}
std::vector<uint8_t> payload(ReadTransportBytes(static_cast<size_t>(payloadLength), context));
if (IsTransportEof())
{
return nullptr;
}
return std::make_shared<WebSocketInternalFrame>(opcode, isFinal, payload);
}
uint8_t WebSocketImplementation::ReadTransportByte(Azure::Core::Context const& context)
{
if (m_bufferPos >= m_bufferLen)
{
// Start by reading data from our initial body stream.
m_bufferLen = m_initialBodyStream->ReadToCount(m_buffer, m_bufferSize, context);
if (m_bufferLen == 0)
{
// If we run out of the initial stream, we need to read from the transport.
m_bufferLen = m_transport->ReadFromSocket(m_buffer, m_bufferSize, context);
m_receiveStatistics.TransportReads++;
m_receiveStatistics.TransportReadBytes += static_cast<uint32_t>(m_bufferLen);
}
else
{
Azure::Core::Diagnostics::_internal::Log::Write(
Azure::Core::Diagnostics::Logger::Level::Informational,
"Read data from initial stream");
}
m_bufferPos = 0;
if (m_bufferLen == 0)
{
m_eof = true;
return 0;
}
}
m_receiveStatistics.BytesReceived++;
return m_buffer[m_bufferPos++];
}
uint16_t WebSocketImplementation::ReadTransportShort(Azure::Core::Context const& context)
{
uint16_t result = ReadTransportByte(context);
result <<= 8;
result |= ReadTransportByte(context);
return result;
}
uint64_t WebSocketImplementation::ReadTransportInt64(Azure::Core::Context const& context)
{
uint64_t result = 0;
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 56 & 0xff00000000000000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 48 & 0x00ff000000000000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 40 & 0x0000ff0000000000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 32 & 0x000000ff00000000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 24 & 0x00000000ff000000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 16 & 0x0000000000ff0000);
result |= (static_cast<uint64_t>(ReadTransportByte(context)) << 8 & 0x000000000000ff00);
result |= static_cast<uint64_t>(ReadTransportByte(context));
return result;
}
std::vector<uint8_t> WebSocketImplementation::ReadTransportBytes(
size_t readLength,
Azure::Core::Context const& context)
{
std::vector<uint8_t> result;
size_t index = 0;
while (index < readLength)
{
uint8_t byte = ReadTransportByte(context);
result.push_back(byte);
index += 1;
}
return result;
}
void WebSocketImplementation::SendTransportBuffer(
std::vector<uint8_t> const& sendFrame,
Azure::Core::Context const& context)
{
std::unique_lock<std::mutex> transportLock(m_transportMutex);
m_receiveStatistics.BytesSent += static_cast<uint32_t>(sendFrame.size());
m_receiveStatistics.FramesSent += 1;
m_transport->SendBuffer(sendFrame.data(), sendFrame.size(), context);
}
// Verify the Sec-WebSocket-Accept header as defined in RFC 6455 Section 1.3, which defines
// the opening handshake used for establishing the WebSocket connection.
std::string acceptHeaderGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
void WebSocketImplementation::VerifySocketAccept(
std::string const& encodedKey,
std::string const& acceptHeader)
{
std::string concatenatedKey(encodedKey);
concatenatedKey += acceptHeaderGuid;
Azure::Core::Cryptography::_internal::Sha1Hash sha1hash;
sha1hash.Append(
reinterpret_cast<const uint8_t*>(concatenatedKey.data()), concatenatedKey.size());
auto keyHash = sha1hash.Final();
std::string encodedHash = Azure::Core::Convert::Base64Encode(keyHash);
if (encodedHash != acceptHeader)
{
throw std::runtime_error(
"Hash returned by WebSocket server does not match expected hash. Aborting");
}
}
WebSocketImplementation::PingThread::PingThread(
WebSocketImplementation* socketImplementation,
std::chrono::duration<int64_t> pingInterval)
: m_webSocketImplementation(socketImplementation), m_pingInterval(pingInterval)
{
}
void WebSocketImplementation::PingThread::Start(std::shared_ptr<WebSocketTransport> transport)
{
m_stop = false;
// Spin up a thread to receive data from the transport.
if (!transport->HasBuiltInWebSocketSupport())
{
std::unique_lock<std::mutex> lock(m_pingThreadStarted);
m_pingThread = std::thread{&PingThread::PingThreadLoop, this};
m_pingThreadReady.wait(lock);
}
}
WebSocketImplementation::PingThread::~PingThread()
{
// Ensure that the receive thread is stopped.
Shutdown();
}
void WebSocketImplementation::PingThread::Shutdown()
{
if (m_pingThread.joinable())
{
std::unique_lock<std::mutex> lock(m_stopMutex);
m_stop = true;
lock.unlock();
m_pingThreadStopped.notify_all();
m_pingThread.join();
}
}
void WebSocketImplementation::PingThread::PingThreadLoop()
{
Log::Write(Logger::Level::Verbose, "Start Ping Thread Loop.");
{
std::unique_lock<std::mutex> lock(m_pingThreadStarted);
m_pingThreadReady.notify_all();
}
while (true)
{
std::unique_lock<std::mutex> lock(m_stopMutex);
if (this->m_pingThreadStopped.wait_for(lock, m_pingInterval) == std::cv_status::timeout)
{
Log::Write(Logger::Level::Verbose, "Send Ping to peer.");
// The receiveContext timed out, this means we timed out our "ping" timeout.
// Send a "Ping" request to the remote node.
auto pingData = GenerateRandomBytes(4);
SendPing(pingData, Azure::Core::Context{});
}
if (m_stop)
{
Log::Write(Logger::Level::Verbose, "Exiting ping thread");
return;
}
}
}
bool WebSocketImplementation::PingThread::SendPing(
std::vector<uint8_t> const& pingData,
Azure::Core::Context const& context)
{
std::vector<uint8_t> pingFrame = EncodeFrame(SocketOpcode::Ping, true, pingData);
m_webSocketImplementation->m_receiveStatistics.PingFramesSent++;
m_webSocketImplementation->SendTransportBuffer(pingFrame, context);
return true;
}
void WebSocketImplementation::SendPong(
std::vector<uint8_t> const& pongData,
Azure::Core::Context const& context)
{
std::vector<uint8_t> pongFrame = EncodeFrame(SocketOpcode::Pong, true, pongData);
m_receiveStatistics.PongFramesSent++;
SendTransportBuffer(pongFrame, context);
}
// Generator for random bytes. Used in WebSocketImplementation and tests.
std::vector<uint8_t> GenerateRandomBytes(size_t vectorSize)
{
std::random_device randomEngine;
std::vector<uint8_t> rv(vectorSize);
std::generate(begin(rv), end(rv), [&randomEngine]() mutable {
return static_cast<uint8_t>(randomEngine() % UINT8_MAX);
});
return rv;
}
}}}}} // namespace Azure::Core::Http::WebSockets::_detail

View File

@ -0,0 +1,373 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/http/websockets/websockets.hpp"
#include "azure/core/http/websockets/websockets_transport.hpp"
#include "azure/core/internal/diagnostics/log.hpp"
#include "azure/core/internal/http/pipeline.hpp"
#include <array>
#include <condition_variable>
#include <queue>
#include <random>
#include <shared_mutex>
#include <thread>
// Implementation of WebSocket protocol.
namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail {
// Generator for random bytes. Used in WebSocketImplementation and tests.
std::vector<uint8_t> GenerateRandomBytes(size_t vectorSize);
class WebSocketImplementation {
enum class SocketState
{
Invalid,
Closed,
Opening,
Open,
Closing,
};
public:
WebSocketImplementation(
Azure::Core::Url const& remoteUrl,
_internal::WebSocketOptions const& options);
void Open(Azure::Core::Context const& context);
void Close(
uint16_t closeStatus,
std::string const& closeReason,
Azure::Core::Context const& context);
void SendFrame(
std::string const& textFrame,
bool isFinalFrame,
Azure::Core::Context const& context);
void SendFrame(
std::vector<uint8_t> const& binaryFrame,
bool isFinalFrame,
Azure::Core::Context const& context);
std::shared_ptr<_internal::WebSocketFrame> ReceiveFrame(Azure::Core::Context const& context);
void AddHeader(std::string const& headerName, std::string const& headerValue);
std::string const& GetNegotiatedProtocol();
bool IsOpen() { return m_state == SocketState::Open; }
bool HasBuiltInWebSocketSupport();
_internal::WebSocketStatistics GetStatistics() const;
private:
// WebSocket opcodes.
enum class SocketOpcode : uint8_t
{
Continuation = 0x00,
TextFrame = 0x01,
BinaryFrame = 0x02,
Close = 0x08,
Ping = 0x09,
Pong = 0x0a
};
/**
* Indicates the type of the message currently being processed. Used when processing
* Continuation Opcode frames.
*/
enum class SocketMessageType : int
{
Unknown,
Text,
Binary,
};
class WebSocketInternalFrame {
public:
SocketOpcode Opcode{};
bool IsFinalFrame{false};
std::vector<uint8_t> Payload;
std::exception_ptr Exception;
WebSocketInternalFrame(
SocketOpcode opcode,
bool isFinalFrame,
std::vector<uint8_t> const& payload)
: Opcode(opcode), IsFinalFrame(isFinalFrame), Payload(payload)
{
}
WebSocketInternalFrame(std::exception_ptr exception) : Exception(exception) {}
};
struct ReceiveStatistics
{
std::atomic<uint32_t> FramesSent;
std::atomic<uint32_t> FramesReceived;
std::atomic<uint32_t> BytesSent;
std::atomic<uint32_t> BytesReceived;
std::atomic<uint32_t> PingFramesSent;
std::atomic<uint32_t> PingFramesReceived;
std::atomic<uint32_t> PongFramesSent;
std::atomic<uint32_t> PongFramesReceived;
std::atomic<uint32_t> TextFramesReceived;
std::atomic<uint32_t> BinaryFramesReceived;
std::atomic<uint32_t> ContinuationFramesReceived;
std::atomic<uint32_t> CloseFramesReceived;
std::atomic<uint32_t> UnknownFramesReceived;
std::atomic<uint32_t> FramesDropped;
std::atomic<uint32_t> FramesDroppedByPayloadSizeLimit;
std::atomic<uint32_t> FramesDroppedByProtocolError;
std::atomic<uint32_t> TransportReads;
std::atomic<uint32_t> TransportReadBytes;
std::atomic<uint32_t> BinaryFramesSent;
std::atomic<uint32_t> TextFramesSent;
std::atomic<uint32_t> FramesDroppedByClose;
void Reset()
{
FramesSent = 0;
BytesSent = 0;
FramesReceived = 0;
BytesReceived = 0;
PingFramesReceived = 0;
PingFramesSent = 0;
PongFramesReceived = 0;
PongFramesSent = 0;
TextFramesReceived = 0;
TextFramesSent = 0;
BinaryFramesReceived = 0;
BinaryFramesSent = 0;
ContinuationFramesReceived = 0;
CloseFramesReceived = 0;
UnknownFramesReceived = 0;
FramesDropped = 0;
FramesDroppedByClose = 0;
FramesDroppedByPayloadSizeLimit = 0;
FramesDroppedByProtocolError = 0;
TransportReads = 0;
TransportReadBytes = 0;
}
};
/**
* @brief The PingThread handles sending Ping operations from the WebSocket server.
*
*/
class PingThread {
public:
/**
* @brief Construct a new ReceiveQueue object.
*
* @param webSocketImplementation Parent object, used to send Ping threads.
* @param pingInterval Interval to wait between sending pings.
*/
PingThread(
WebSocketImplementation* webSocketImplementation,
std::chrono::duration<int64_t> pingInterval);
/**
* @brief Destroys a ReceiveQueue object. Blocks until the queue thread is completed.
*/
~PingThread();
/**
* @brief Start the receive queue. This will start a thread that will process incoming frames.
*
* @param transport The websocket transport to use for receiving frames.
*/
void Start(std::shared_ptr<WebSocketTransport> transport);
/**
* @brief Stop the receive queue. This will stop the thread that processes incoming frames.
*/
void Shutdown();
private:
/**
* @brief The receive queue thread.
*/
void PingThreadLoop();
/**
* @brief Send a "ping" frame to the other side of the WebSocket.
*
* @returns True if the ping was sent, false if the underlying transport didn't support "Ping"
* operations.
*/
bool SendPing(std::vector<uint8_t> const& pingData, Azure::Core::Context const& context);
WebSocketImplementation* m_webSocketImplementation;
std::chrono::duration<int64_t> m_pingInterval;
std::thread m_pingThread;
std::mutex m_pingThreadStarted;
std::condition_variable m_pingThreadReady;
std::mutex m_stopMutex;
std::condition_variable m_pingThreadStopped;
bool m_stop = false;
};
/**
* @brief Encode a websocket frame according to RFC 6455 section 5.2.
*
* This wire format for the data transfer part is described by the ABNF
* [RFC5234] given in detail in this section. (Note that, unlike in
* other sections of this document, the ABNF in this section is
* operating on groups of bits. The length of each group of bits is
* indicated in a comment. When encoded on the wire, the most
* significant bit is the leftmost in the ABNF). A high-level overview
* of the framing is given in the following figure. In a case of
* conflict between the figure below and the ABNF specified later in
* this section, the figure is authoritative.
*
* 0 1 2 3
* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
* +-+-+-+-+-------+-+-------------+-------------------------------+
* |F|R|R|R| opcode|M| Payload len | Extended payload length |
* |I|S|S|S| (4) |A| (7) | (16/64) |
* |N|V|V|V| |S| | (if payload len==126/127) |
* | |1|2|3| |K| | |
* +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
* | Extended payload length continued, if payload len == 127 |
* + - - - - - - - - - - - - - - - +-------------------------------+
* | |Masking-key, if MASK set to 1 |
* +-------------------------------+-------------------------------+
* | Masking-key (continued) | Payload Data |
* +-------------------------------- - - - - - - - - - - - - - - - +
* : Payload Data continued ... :
* + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
* | Payload Data continued ... |
* +---------------------------------------------------------------+
*
* FIN: 1 bit
*
* Indicates that this is the final fragment in a message. The first
* fragment MAY also be the final fragment.
*
* RSV1, RSV2, RSV3: 1 bit each
*
* MUST be 0 unless an extension is negotiated that defines meanings
* for non-zero values. If a nonzero value is received and none of
* the negotiated extensions defines the meaning of such a nonzero
* value, the receiving endpoint MUST _Fail the WebSocket
* Connection_.
*
* Opcode: 4 bits
*
* Defines the interpretation of the "Payload data". If an unknown
* opcode is received, the receiving endpoint MUST _Fail the
* WebSocket Connection_. The following values are defined.
*
* * %x0 denotes a continuation frame
*
* * %x1 denotes a text frame
*
* * %x2 denotes a binary frame
*
* * %x3-7 are reserved for further non-control frames
*
* * %x8 denotes a connection close
*
* * %x9 denotes a ping
*
* * %xA denotes a pong
*
* * %xB-F are reserved for further control frames
*
* Mask: 1 bit
*
* Defines whether the "Payload data" is masked. If set to 1, a
* masking key is present in masking-key, and this is used to unmask
* the "Payload data" as per Section 5.3. All frames sent from
* client to server have this bit set to 1.
*
* Payload length: 7 bits, 7+16 bits, or 7+64 bits
*
* The length of the "Payload data", in bytes: if 0-125, that is the
* payload length. If 126, the following 2 bytes interpreted as a
* 16-bit unsigned integer are the payload length. If 127, the
* following 8 bytes interpreted as a 64-bit unsigned integer (the
* most significant bit MUST be 0) are the payload length. Multibyte
* length quantities are expressed in network byte order. Note that
* in all cases, the minimal number of bytes MUST be used to encode
* the length, for example, the length of a 124-byte-long string
* can't be encoded as the sequence 126, 0, 124. The payload length
* is the length of the "Extension data" + the length of the
* "Application data". The length of the "Extension data" may be
* zero, in which case the payload length is the length of the
* "Application data".
* Masking-key: 0 or 4 bytes
*
* All frames sent from the client to the server are masked by a
* 32-bit value that is contained within the frame. This field is
* present if the mask bit is set to 1 and is absent if the mask bit
* is set to 0. See Section 5.3 for further information on client-
* to-server masking.
*
* Payload data: (x+y) bytes
*
* The "Payload data" is defined as "Extension data" concatenated
* with "Application data".
*
* Extension data: x bytes
*
* The "Extension data" is 0 bytes unless an extension has been
* negotiated. Any extension MUST specify the length of the
* "Extension data", or how that length may be calculated, and how
* the extension use MUST be negotiated during the opening handshake.
* If present, the "Extension data" is included in the total payload
* length.
*
* Application data: y bytes
*
* Arbitrary "Application data", taking up the remainder of the frame
* after any "Extension data". The length of the "Application data"
* is equal to the payload length minus the length of the "Extension
* data".
*/
static std::vector<uint8_t> EncodeFrame(
SocketOpcode opcode,
bool isFinal,
std::vector<uint8_t> const& payload);
SocketState m_state{SocketState::Invalid};
std::vector<uint8_t> GenerateRandomKey() { return GenerateRandomBytes(16); };
void VerifySocketAccept(std::string const& encodedKey, std::string const& acceptHeader);
/*********
* Buffered Read Support. Read data from the underlying transport into a buffer.
*/
uint8_t ReadTransportByte(Azure::Core::Context const& context);
uint16_t ReadTransportShort(Azure::Core::Context const& context);
uint64_t ReadTransportInt64(Azure::Core::Context const& context);
std::vector<uint8_t> ReadTransportBytes(size_t readLength, Azure::Core::Context const& context);
bool IsTransportEof() const { return m_eof; }
void SendPong(std::vector<uint8_t> const& pongData, Azure::Core::Context const& context);
void SendTransportBuffer(
std::vector<uint8_t> const& payload,
Azure::Core::Context const& context);
std::shared_ptr<WebSocketInternalFrame> ReceiveTransportFrame(
Azure::Core::Context const& context);
/**
* @brief Decode a frame received from the websocket server.
*
* @returns A pointer to the start of the decoded data.
*/
std::shared_ptr<WebSocketInternalFrame> DecodeFrame(Azure::Core::Context const& context);
Azure::Core::Url m_remoteUrl;
_internal::WebSocketOptions m_options;
std::map<std::string, std::string> m_headers;
std::string m_chosenProtocol;
std::shared_ptr<Azure::Core::Http::WebSockets::WebSocketTransport> m_transport;
PingThread m_pingThread;
SocketMessageType m_currentMessageType{SocketMessageType::Unknown};
std::mutex m_stateMutex;
std::thread::id m_stateOwner;
ReceiveStatistics m_receiveStatistics{};
std::mutex m_transportMutex;
std::unique_ptr<Azure::Core::IO::BodyStream> m_initialBodyStream;
constexpr static size_t m_bufferSize = 1024;
uint8_t m_buffer[m_bufferSize]{};
size_t m_bufferPos = 0;
size_t m_bufferLen = 0;
bool m_eof = false;
};
}}}}} // namespace Azure::Core::Http::WebSockets::_detail

View File

@ -0,0 +1,221 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "azure/core/http/http.hpp"
#include "azure/core/http/policies/policy.hpp"
#include "azure/core/http/transport.hpp"
#include "azure/core/http/websockets/win_http_websockets_transport.hpp"
#include "azure/core/internal/diagnostics/log.hpp"
#include "azure/core/platform.hpp"
#if defined(AZ_PLATFORM_POSIX)
#include <poll.h> // for poll()
#include <sys/socket.h> // for socket shutdown
#elif defined(AZ_PLATFORM_WINDOWS)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#if !defined(NOMINMAX)
#define NOMINMAX
#endif
#include <winapifamily.h>
#include <winsock2.h> // for WSAPoll();
#endif
#include <shared_mutex>
namespace Azure { namespace Core { namespace Http { namespace WebSockets {
void WinHttpWebSocketTransport::OnUpgradedConnection(
Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle)
{
// Convert the request handle into a WebSocket handle for us to use later.
m_socketHandle = Azure::Core::Http::_detail::unique_HINTERNET(
WinHttpWebSocketCompleteUpgrade(requestHandle.get(), 0),
Azure::Core::Http::_detail::HINTERNET_deleter{});
if (!m_socketHandle)
{
GetErrorAndThrow("Error Upgrading HttpRequest handle to WebSocket handle.");
}
}
std::unique_ptr<Azure::Core::Http::RawResponse> WinHttpWebSocketTransport::Send(
Azure::Core::Http::Request& request,
Azure::Core::Context const& context)
{
return WinHttpTransport::Send(request, context);
}
/**
* @brief Close the WebSocket cleanly.
*/
void WinHttpWebSocketTransport::Close() { m_socketHandle.reset(); }
// Native WebSocket support methods.
/**
* @brief Gracefully closes the WebSocket, notifying the remote node of the close reason.
*
* @details Not implemented for CURL websockets because CURL does not support native websockets.
*
* @param status Status value to be sent to the remote node. Application defined.
* @param disconnectReason UTF-8 encoded reason for the disconnection. Optional.
* @param context Context for the operation.
*
*/
void WinHttpWebSocketTransport::NativeCloseSocket(
uint16_t status,
std::string const& disconnectReason,
Azure::Core::Context const& context)
{
context.ThrowIfCancelled();
auto err = WinHttpWebSocketClose(
m_socketHandle.get(),
status,
disconnectReason.empty()
? nullptr
: reinterpret_cast<PVOID>(const_cast<char*>(disconnectReason.c_str())),
static_cast<DWORD>(disconnectReason.size()));
if (err != 0)
{
GetErrorAndThrow("WinHttpWebSocketClose() failed", err);
}
context.ThrowIfCancelled();
// Make sure that the server responds gracefully to the close request.
auto closeInformation = NativeGetCloseSocketInformation(context);
// The server should return the same status we sent.
if (closeInformation.CloseReason != status)
{
throw std::runtime_error(
"Close status mismatch, got " + std::to_string(closeInformation.CloseReason)
+ " expected " + std::to_string(status));
}
}
/**
* @brief Retrieve the information associated with a WebSocket close response.
*
* Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType
*
* @param context Context for the operation.
*
* @returns a tuple containing the status code and string.
*/
WinHttpWebSocketTransport::NativeWebSocketCloseInformation
WinHttpWebSocketTransport::NativeGetCloseSocketInformation(Azure::Core::Context const& context)
{
context.ThrowIfCancelled();
uint16_t closeStatus = 0;
char closeReason[WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH]{};
DWORD closeReasonLength;
auto err = WinHttpWebSocketQueryCloseStatus(
m_socketHandle.get(),
&closeStatus,
closeReason,
WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH,
&closeReasonLength);
if (err != 0)
{
GetErrorAndThrow("WinHttpGetCloseStatus() failed", err);
}
return NativeWebSocketCloseInformation{closeStatus, std::string(closeReason)};
}
/**
* @brief Send a frame of data to the remote node.
*
* @details Not implemented for CURL websockets because CURL does not support native
* websockets.
*
* @brief frameType Frame type sent to the server, Text or Binary.
* @brief frameData Frame data to be sent to the server.
*/
void WinHttpWebSocketTransport::NativeSendFrame(
NativeWebSocketFrameType frameType,
std::vector<uint8_t> const& frameData,
Azure::Core::Context const& context)
{
context.ThrowIfCancelled();
WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType;
switch (frameType)
{
case NativeWebSocketFrameType::Text:
bufferType = WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE;
break;
case NativeWebSocketFrameType::Binary:
bufferType = WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE;
break;
case NativeWebSocketFrameType::BinaryFragment:
bufferType = WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE;
break;
case NativeWebSocketFrameType::TextFragment:
bufferType = WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE;
break;
default:
throw std::runtime_error(
"Unknown frame type: " + std::to_string(static_cast<uint32_t>(frameType)));
break;
}
// Lock the socket to prevent concurrent writes. WinHTTP gets annoyed if
// there are multiple WinHttpWebSocketSend requests outstanding.
std::lock_guard<std::mutex> lock(m_sendMutex);
auto err = WinHttpWebSocketSend(
m_socketHandle.get(),
bufferType,
reinterpret_cast<PVOID>(const_cast<uint8_t*>(frameData.data())),
static_cast<DWORD>(frameData.size()));
if (err != 0)
{
GetErrorAndThrow("WinHttpWebSocketSend() failed", err);
}
}
WinHttpWebSocketTransport::NativeWebSocketReceiveInformation
WinHttpWebSocketTransport::NativeReceiveFrame(Azure::Core::Context const& context)
{
WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType;
NativeWebSocketFrameType frameTypeReceived;
DWORD bufferBytesRead;
std::vector<uint8_t> buffer(128);
context.ThrowIfCancelled();
std::lock_guard<std::mutex> lock(m_receiveMutex);
auto err = WinHttpWebSocketReceive(
m_socketHandle.get(),
reinterpret_cast<PVOID>(buffer.data()),
static_cast<DWORD>(buffer.size()),
&bufferBytesRead,
&bufferType);
if (err != 0 && err != ERROR_INSUFFICIENT_BUFFER)
{
GetErrorAndThrow("WinHttpWebSocketReceive() failed", err);
}
buffer.resize(bufferBytesRead);
switch (bufferType)
{
case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE:
frameTypeReceived = NativeWebSocketFrameType::Text;
break;
case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE:
frameTypeReceived = NativeWebSocketFrameType::Binary;
break;
case WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE:
frameTypeReceived = NativeWebSocketFrameType::BinaryFragment;
break;
case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE:
frameTypeReceived = NativeWebSocketFrameType::TextFragment;
break;
case WINHTTP_WEB_SOCKET_CLOSE_BUFFER_TYPE:
frameTypeReceived = NativeWebSocketFrameType::Closed;
break;
default:
throw std::runtime_error("Unknown frame type: " + std::to_string(bufferType));
break;
}
return NativeWebSocketReceiveInformation{frameTypeReceived, buffer};
}
}}}} // namespace Azure::Core::Http::WebSockets

View File

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation. All rights reserved.
# SPDX-License-Identifier: MIT
param(
[string] $LogFileLocation = "$($env:BUILD_SOURCESDIRECTORY)/WebSocketServer.log"
)

View File

@ -0,0 +1,155 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# SPDX-License-Identifier: MIT
from array import array
import asyncio
from operator import length_hint
import threading
from time import sleep
from urllib.parse import ParseResult, urlparse
import websockets
# create handler for each connection
customPaths = {}
stop = False
async def handleControlPath(websocket):
while (1):
data : str = await websocket.recv()
parsedCommand = data.split(' ')
if (parsedCommand[0] == "close"):
print("Closing control channel")
await websocket.send("ok")
print("Terminating WebSocket server.")
stop.set_result(0)
break
elif parsedCommand[0] == "newPath":
print("Add path")
newPath = parsedCommand[1]
print(" Add path ", newPath)
customPaths[newPath] = {"path": newPath, "delay": int(parsedCommand[2]) }
await websocket.send("ok")
else:
print("Unknown command, echoing it.")
await websocket.send(data)
async def handleCustomPath(websocket, path:dict):
print("Handle custom path", path)
data : str = await websocket.recv()
print("Received ", data)
if ("delay" in path.keys()):
sleep(path["delay"])
print("Responding")
await websocket.send(data)
await websocket.close()
def HexEncode(data: bytes)->str:
rv=""
for val in data:
rv+= '{:02X}'.format(val)
return rv
def ParseQuery(url : ParseResult) -> dict:
rv={}
if len(url.query)!=0:
args = url.query.split('&')
for arg in args:
vals=arg.split('=')
rv[vals[0]]=vals[1]
return rv
echo_count_lock = threading.Lock()
echo_count_recv = 0
echo_count_send = 0
client_count = 0
async def handleEcho(websocket, url:ParseResult):
global client_count
global echo_count_recv
global echo_count_send
global echo_count_lock
queryValues = ParseQuery(url)
while websocket.open:
try:
data = await websocket.recv()
with echo_count_lock:
echo_count_recv+=1
if 'delay' in queryValues:
print(f"sleeping for {queryValues['delay']} seconds")
await asyncio.sleep(float(queryValues['delay']))
print("woken up.")
if 'fragment' in queryValues and queryValues['fragment']=='true':
await websocket.send(data.split())
else:
await websocket.send(data)
with echo_count_lock:
echo_count_send+=1
except websockets.ConnectionClosedOK:
print("Connection closed ok.")
with echo_count_lock:
client_count -= 1
print(f"Echo count: {echo_count_recv}, {echo_count_send} client_count {client_count}")
if client_count == 0:
echo_count_send = 0
echo_count_recv = 0
return
except websockets.ConnectionClosed as ex:
if (ex.rcvd):
print(f"Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}")
else:
print(f"Connection closed. No close information.")
with echo_count_lock:
client_count -= 1
print(f"Echo count: recv: {echo_count_recv}, send: {echo_count_send} client_count {client_count}")
if client_count == 0:
echo_count_send = 0
echo_count_recv = 0
return
async def handler(websocket, path : str):
global client_count
print("Socket handler: ", path)
parsedUrl = urlparse(path)
if (parsedUrl.path == '/openclosetest'):
print("Open/Close Test")
try:
data = await websocket.recv()
print(f"OpenCloseTest: Received {data}")
except websockets.ConnectionClosedOK:
print("OpenCloseTest: Connection closed ok.")
except websockets.ConnectionClosed as ex:
print(f"OpenCloseTest: Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}")
return
elif (parsedUrl.path == '/echotest'):
with echo_count_lock:
client_count+= 1
await handleEcho(websocket, parsedUrl)
elif (parsedUrl.path == '/closeduringecho'):
data = await websocket.recv()
await websocket.close(1001, 'closed')
elif (parsedUrl.path =='/control'):
await handleControlPath(websocket)
elif (parsedUrl.path in customPaths.keys()):
print("Found path ", path, "in control paths.")
await handleCustomPath(websocket, customPaths[path])
elif (parsedUrl.path == '/terminateserver'):
print("Terminating WebSocket server.")
stop.set_result(0)
else:
data = await websocket.recv()
print("Received: ", data)
reply = f"Data received as: {data}!"
await websocket.send(reply)
async def main():
global stop
print("Starting server")
loop = asyncio.get_running_loop()
stop = loop.create_future()
async with websockets.serve(handler, "localhost", 8000, ping_interval=7):
await stop # run forever.
if __name__=="__main__":
asyncio.run(main())
print("Ending server")

View File

@ -0,0 +1,875 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MIT
#include "../../src/http/websockets/websockets_impl.hpp"
#include "azure/core/http/websockets/websockets.hpp"
#include "azure/core/internal/json/json.hpp"
#include <chrono>
#include <gtest/gtest.h>
#include <list>
#include <set>
#include <thread>
#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
#include "azure/core/http/websockets/curl_websockets_transport.hpp"
#endif
// cspell::words closeme flibbityflobbidy
using namespace Azure::Core;
using namespace Azure::Core::Http::WebSockets;
using namespace Azure::Core::Http::WebSockets::_internal;
using namespace std::chrono_literals;
constexpr uint16_t UndefinedButLegalCloseReason = 4500;
class WebSocketTests : public testing::Test {
private:
protected:
// Create
static void SetUpTestSuite() {}
static void TearDownTestSuite() {}
};
TEST_F(WebSocketTests, CreateSimpleSocket)
{
{
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000"));
defaultSocket.AddHeader("newHeader", "headerValue");
EXPECT_THROW(defaultSocket.GetNegotiatedProtocol(), std::runtime_error);
}
}
TEST_F(WebSocketTests, OpenSimpleSocket)
{
{
WebSocketOptions options;
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"), options);
defaultSocket.AddHeader("newHeader", "headerValue");
defaultSocket.Open();
EXPECT_THROW(defaultSocket.AddHeader("newHeader", "headerValue"), std::runtime_error);
// Close the socket without notifying the peer.
defaultSocket.Close();
}
{
WebSocketOptions options;
WebSocket defaultSocket(Azure::Core::Url("http://www.microsoft.com/"), options);
defaultSocket.AddHeader("newHeader", "headerValue");
// When running this test locally, the call times out, so drop in a 5 second timeout on
// the request.
Azure::Core::Context requestContext = Azure::Core::Context::ApplicationContext.WithDeadline(
std::chrono::system_clock::now() + 5s);
EXPECT_THROW(defaultSocket.Open(requestContext), std::runtime_error);
}
}
TEST_F(WebSocketTests, OpenAndCloseSocket)
{
if (false)
{
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"));
defaultSocket.AddHeader("newHeader", "headerValue");
defaultSocket.Open();
// Close the socket without notifying the peer.
defaultSocket.Close(UndefinedButLegalCloseReason);
}
{
WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"));
defaultSocket.Open();
// Close the socket without notifying the peer.
defaultSocket.Close(UndefinedButLegalCloseReason, "This is a good reason.");
//
// Now re-open the socket - this should work to reset everything.
defaultSocket.Open();
EXPECT_THROW(defaultSocket.Open(), std::runtime_error);
defaultSocket.Close();
}
}
TEST_F(WebSocketTests, SimpleEcho)
{
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
testSocket.Open();
testSocket.SendFrame("Test message", true);
auto response = testSocket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::TextFrameReceived, response->FrameType);
EXPECT_THROW(response->AsBinaryFrame(), std::logic_error);
auto textResult = response->AsTextFrame();
EXPECT_EQ("Test message", textResult->Text);
// Close the socket gracefully.
testSocket.Close();
}
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?delay=5"));
testSocket.Open();
std::vector<uint8_t> binaryData{1, 2, 3, 4, 5, 6};
testSocket.SendFrame(binaryData, true);
auto response = testSocket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
EXPECT_THROW(response->AsPeerCloseFrame(), std::logic_error);
EXPECT_THROW(response->AsTextFrame(), std::logic_error);
auto textResult = response->AsBinaryFrame();
EXPECT_EQ(binaryData, textResult->Data);
// Close the socket gracefully.
testSocket.Close();
}
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?fragment=true&delay=5"));
testSocket.Open();
std::vector<uint8_t> binaryData{1, 2, 3, 4, 5, 6};
testSocket.SendFrame(binaryData, true);
std::vector<uint8_t> responseData;
std::shared_ptr<WebSocketFrame> response;
do
{
response = testSocket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
auto binaryResult = response->AsBinaryFrame();
responseData.insert(responseData.end(), binaryResult->Data.begin(), binaryResult->Data.end());
} while (!response->IsFinalFrame);
auto textResult = response->AsBinaryFrame();
EXPECT_EQ(binaryData, responseData);
// Close the socket gracefully.
testSocket.Close();
}
}
template <size_t N> void EchoRandomData(WebSocket& socket)
{
std::vector<uint8_t> sendData = Azure::Core::Http::WebSockets::_detail::GenerateRandomBytes(N);
socket.SendFrame(sendData, true);
std::vector<uint8_t> receiveData;
std::shared_ptr<WebSocketFrame> response;
do
{
response = socket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
auto binaryResult = response->AsBinaryFrame();
receiveData.insert(receiveData.end(), binaryResult->Data.begin(), binaryResult->Data.end());
} while (!response->IsFinalFrame);
// Make sure we get back the data we sent in the echo request.
EXPECT_EQ(sendData.size(), receiveData.size());
EXPECT_EQ(sendData, receiveData);
}
TEST_F(WebSocketTests, VariableSizeEcho)
{
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
testSocket.Open();
{
EchoRandomData<100>(testSocket);
EchoRandomData<124>(testSocket);
EchoRandomData<125>(testSocket);
// The websocket protocol treats lengths of 125, 126 and > 127 specially.
EchoRandomData<126>(testSocket);
EchoRandomData<127>(testSocket);
EchoRandomData<128>(testSocket);
EchoRandomData<1020>(testSocket); // 1K-4
EchoRandomData<1021>(testSocket); // 1K-3
EchoRandomData<1022>(testSocket); // 1K-2
EchoRandomData<1023>(testSocket); // 1K-1
EchoRandomData<1024>(testSocket); // 1K
EchoRandomData<2048>(testSocket); // 2K
EchoRandomData<4096>(testSocket); // 4K
EchoRandomData<8192>(testSocket); // 8K
// The websocket protocol treats lengths of >65536 specially.
EchoRandomData<65535>(testSocket); // 64K-1
EchoRandomData<65536>(testSocket); // 64K
EchoRandomData<65537>(testSocket); // 64K+1
EchoRandomData<131072>(testSocket); // 128K
}
// Close the socket gracefully.
testSocket.Close();
}
}
// Generator for random bytes. Used in WebSocketImplementation and tests.
std::vector<uint8_t> GenerateRandomBytes(size_t index, size_t vectorSize)
{
std::random_device randomEngine;
std::vector<uint8_t> rv(vectorSize + 4);
rv[0] = index & 0xff;
rv[1] = (index >> 8) & 0xff;
rv[2] = (index >> 16) & 0xff;
rv[3] = (index >> 24) & 0xff;
std::generate(std::begin(rv) + 4, std::end(rv), [&randomEngine]() mutable {
return static_cast<uint8_t>(randomEngine() % UINT8_MAX);
});
return rv;
}
TEST_F(WebSocketTests, CloseDuringEcho)
{
{
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho"));
testSocket.Open();
testSocket.SendFrame("Test message", true);
auto response = testSocket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::PeerClosedReceived, response->FrameType);
auto PeerClosedReceived = response->AsPeerCloseFrame();
EXPECT_EQ(1001, PeerClosedReceived->RemoteStatusCode);
// Close the socket gracefully.
testSocket.Close();
}
// Close the websocket while a thread is waiting for a response.
{
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/echotest?delay=10"));
testSocket.Open();
std::thread testThread([&]() {
try
{
std::vector<uint8_t> sendData = GenerateRandomBytes(0, 100);
testSocket.SendFrame(sendData);
GTEST_LOG_(INFO) << "Receive frame.";
auto response = testSocket.ReceiveFrame();
GTEST_LOG_(INFO) << "Received frame.";
if (response->FrameType == WebSocketFrameType::PeerClosedReceived)
{
GTEST_LOG_(INFO) << "Peer closed the socket; Terminating thread.";
return;
}
else if (response->FrameType != WebSocketFrameType::BinaryFrameReceived)
{
GTEST_LOG_(INFO) << "Unexpected frame type received.";
}
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
auto binaryResult = response->AsBinaryFrame();
}
catch (Azure::Core::OperationCancelledException& ex)
{
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what()
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
}
catch (std::exception const& ex)
{
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
}
});
std::this_thread::sleep_for(100ms);
// Close the socket gracefully.
GTEST_LOG_(INFO) << "Closing Socket.";
EXPECT_NO_THROW(testSocket.Close(UndefinedButLegalCloseReason, "Close Reason."));
GTEST_LOG_(INFO) << "Closed Socket.";
testThread.join();
}
}
TEST_F(WebSocketTests, ExpectThrow)
{
{
WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho"));
EXPECT_THROW(testSocket.SendFrame("Foo", true), std::runtime_error);
std::vector<uint8_t> data{1, 2, 3, 4};
EXPECT_THROW(testSocket.SendFrame(data, true), std::runtime_error);
EXPECT_THROW(testSocket.ReceiveFrame(), std::runtime_error);
}
}
std::string ToHexString(std::vector<uint8_t> const& data)
{
std::stringstream ss;
for (auto const& byte : data)
{
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(byte);
}
return ss.str();
}
TEST_F(WebSocketTests, PingReceiveTest)
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
testSocket.Open();
if (!testSocket.HasBuiltInWebSocketSupport())
{
GTEST_LOG_(INFO) << "Sleeping for 15 seconds to collect pings.";
Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline(
Azure::DateTime{std::chrono::system_clock::now() + 15s});
EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException);
auto statistics = testSocket.GetStatistics();
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
EXPECT_NE(0, statistics.PingFramesReceived);
EXPECT_NE(0, statistics.PongFramesSent);
}
}
TEST_F(WebSocketTests, PingSendTest)
{
// Configure the socket to ping every second.
WebSocketOptions socketOptions;
socketOptions.PingInterval = std::chrono::seconds(1);
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"), socketOptions);
testSocket.Open();
if (!testSocket.HasBuiltInWebSocketSupport())
{
GTEST_LOG_(INFO) << "Sleeping for 10 seconds to collect pings.";
// Note that we cannot collect incoming pings or outgoing pongs unless we are receiving
// data from the server.
Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline(
Azure::DateTime{std::chrono::system_clock::now() + 10s});
EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException);
auto statistics = testSocket.GetStatistics();
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
EXPECT_NE(0, statistics.PingFramesSent);
EXPECT_NE(0, statistics.PongFramesReceived);
EXPECT_NE(0, statistics.PingFramesReceived);
EXPECT_NE(0, statistics.PongFramesSent);
}
}
TEST_F(WebSocketTests, MultiThreadedTestOnSingleSocket)
{
constexpr size_t threadCount = 50;
constexpr size_t testDataLength = 200000;
constexpr size_t testDataSize = 100;
constexpr auto testDuration = 10s;
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
testSocket.Open();
// seed test data for the operations.
std::vector<std::vector<uint8_t>> testData(testDataLength);
std::vector<std::vector<uint8_t>> receivedData(testDataLength);
std::atomic_size_t iterationCount(0);
// Spin up threadCount threads and hammer the echo server for 10 seconds.
std::vector<std::thread> threads;
std::atomic_int32_t cancellationExceptions{0};
std::atomic_int32_t exceptions{0};
for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1)
{
threads.push_back(std::thread([&]() {
std::chrono::time_point<std::chrono::system_clock> startTime
= std::chrono::system_clock::now();
// Set the context to expire *after* the test is supposed to finish.
Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline(
Azure::DateTime{startTime} + testDuration + 10s);
size_t iteration = 0;
try
{
do
{
iteration = iterationCount++;
std::vector<uint8_t> sendData = GenerateRandomBytes(iteration, testDataSize);
{
if (iteration < testData.size())
{
if (testData[iteration].size() != 0)
{
GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl;
}
EXPECT_EQ(0, testData[iteration].size());
testData[iteration] = sendData;
}
}
testSocket.SendFrame(sendData, true /*, context*/);
auto response = testSocket.ReceiveFrame(context);
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
auto binaryResult = response->AsBinaryFrame();
// Make sure we get back the data we sent in the echo request.
if (binaryResult->Data.size() == 0)
{
GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl;
}
EXPECT_EQ(sendData.size(), binaryResult->Data.size());
{
// There is no ordering expectation on the results, so we just remember the data
// as it comes in. We'll make sure we received everything later on.
if (iteration < receivedData.size())
{
if (receivedData[iteration].size() != 0)
{
GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration
<< std::endl;
}
EXPECT_EQ(0, receivedData[iteration].size());
receivedData[iteration] = binaryResult->Data;
}
}
} while (std::chrono::system_clock::now() - startTime < testDuration);
}
catch (Azure::Core::OperationCancelledException& ex)
{
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
cancellationExceptions++;
}
catch (std::exception const& ex)
{
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
exceptions++;
}
}));
}
// Wait for all the threads to exit.
for (auto& thread : threads)
{
thread.join();
}
// We no longer need to worry about synchronization since all the worker threads are done.
GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl;
GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex
<< testData.size() << ")" << std::endl;
EXPECT_GE(testDataLength, iterationCount.load());
auto statistics = testSocket.GetStatistics();
GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent;
GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived;
GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived;
GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent;
GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived;
GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent;
GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent;
GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived;
GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped;
GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads;
GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes;
// Close the socket gracefully.
testSocket.Close();
EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesSent);
EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesReceived);
// Resize the test data to the number of actual iterations.
testData.resize(iterationCount.load());
receivedData.resize(iterationCount.load());
// If we've processed every iteration, let's make sure that we received everything we sent.
// If we dropped some results, then we can't check to ensure that we have received everything
// because we can't account for everything sent.
std::multiset<std::string> testDataStrings;
std::multiset<std::string> receivedDataStrings;
for (auto const& data : testData)
{
testDataStrings.emplace(ToHexString(data));
}
for (auto const& data : receivedData)
{
receivedDataStrings.emplace(ToHexString(data));
}
EXPECT_EQ(testDataStrings, receivedDataStrings);
for (auto const& data : testDataStrings)
{
if (receivedDataStrings.count(data) != testDataStrings.count(data))
{
GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data)
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
<< " Missing Data: " << data << std::endl;
}
EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data));
}
for (auto const& data : receivedDataStrings)
{
if (testDataStrings.count(data) != receivedDataStrings.count(data))
{
GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data)
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
<< " Missing Data: " << data << std::endl;
}
EXPECT_NE(testDataStrings.end(), testDataStrings.find(data));
}
// We shouldn't have seen any exceptions during the run.
EXPECT_EQ(0, exceptions.load());
EXPECT_EQ(0, cancellationExceptions.load());
}
TEST_F(WebSocketTests, MultiThreadedTestOnMultipleSockets)
{
constexpr size_t threadCount = 50;
constexpr size_t testDataLength = 200000;
constexpr size_t testDataSize = 100;
constexpr auto testDuration = 10s;
// seed test data for the operations.
std::vector<std::vector<uint8_t>> testData(testDataLength);
std::vector<std::vector<uint8_t>> receivedData(testDataLength);
std::atomic_size_t iterationCount(0);
// Spin up threadCount threads and hammer the echo server for 10 seconds.
std::vector<std::thread> threads;
std::atomic_int32_t cancellationExceptions{0};
std::atomic_int32_t exceptions{0};
for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1)
{
threads.push_back(std::thread([&]() {
std::chrono::time_point<std::chrono::system_clock> startTime
= std::chrono::system_clock::now();
// Set the context to expire *after* the test is supposed to finish.
Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline(
Azure::DateTime{startTime} + testDuration + 10s);
size_t iteration = 0;
try
{
WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"));
testSocket.Open();
do
{
iteration = iterationCount++;
std::vector<uint8_t> sendData = GenerateRandomBytes(iteration, testDataSize);
{
if (iteration < testData.size())
{
if (testData[iteration].size() != 0)
{
GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl;
}
EXPECT_EQ(0, testData[iteration].size());
testData[iteration] = sendData;
}
}
testSocket.SendFrame(sendData, true /*, context*/);
auto response = testSocket.ReceiveFrame(context);
EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType);
auto binaryResult = response->AsBinaryFrame();
// Make sure we get back the data we sent in the echo request.
if (binaryResult->Data.size() == 0)
{
GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl;
}
EXPECT_EQ(sendData.size(), binaryResult->Data.size());
{
// There is no ordering expectation on the results, so we just remember the data
// as it comes in. We'll make sure we received everything later on.
if (iteration < receivedData.size())
{
if (receivedData[iteration].size() != 0)
{
GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration
<< std::endl;
}
EXPECT_EQ(0, receivedData[iteration].size());
receivedData[iteration] = binaryResult->Data;
}
}
} while (std::chrono::system_clock::now() - startTime < testDuration);
// Close the socket gracefully.
testSocket.Close();
}
catch (Azure::Core::OperationCancelledException& ex)
{
GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration
<< " Current Thread: " << std::this_thread::get_id() << std::endl;
cancellationExceptions++;
}
catch (std::exception const& ex)
{
GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl;
exceptions++;
}
}));
}
// Wait for all the threads to exit.
for (auto& thread : threads)
{
thread.join();
}
// We no longer need to worry about synchronization since all the worker threads are done.
GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl;
GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex
<< testData.size() << ")" << std::endl;
EXPECT_GE(testDataLength, iterationCount.load());
// Resize the test data to the number of actual iterations.
testData.resize(iterationCount.load());
receivedData.resize(iterationCount.load());
// If we've processed every iteration, let's make sure that we received everything we sent.
// If we dropped some results, then we can't check to ensure that we have received everything
// because we can't account for everything sent.
std::multiset<std::string> testDataStrings;
std::multiset<std::string> receivedDataStrings;
for (auto const& data : testData)
{
testDataStrings.emplace(ToHexString(data));
}
for (auto const& data : receivedData)
{
receivedDataStrings.emplace(ToHexString(data));
}
EXPECT_EQ(testDataStrings, receivedDataStrings);
for (auto const& data : testDataStrings)
{
if (receivedDataStrings.count(data) != testDataStrings.count(data))
{
GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data)
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
<< " Missing Data: " << data << std::endl;
}
EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data));
}
for (auto const& data : receivedDataStrings)
{
if (testDataStrings.count(data) != receivedDataStrings.count(data))
{
GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data)
<< " ReceivedDataCount: " << receivedDataStrings.count(data)
<< " Missing Data: " << data << std::endl;
}
EXPECT_NE(testDataStrings.end(), testDataStrings.find(data));
}
// We shouldn't have seen any exceptions during the run.
EXPECT_EQ(0, exceptions.load());
EXPECT_EQ(0, cancellationExceptions.load());
}
// Does not work because curl rejects the wss: scheme.
class LibWebSocketIncrementProtocol {
WebSocketOptions m_options{{"dumb-increment-protocol"}};
WebSocket m_socket;
public:
LibWebSocketIncrementProtocol() : m_socket{Azure::Core::Url("wss://libwebsockets.org"), m_options}
{
}
void Open() { m_socket.Open(); }
int GetNextNumber()
{
// Time out in 5 seconds if no activity.
Azure::Core::Context contextWithTimeout
= Azure::Core::Context().WithDeadline(std::chrono::system_clock::now() + 10s);
auto work = m_socket.ReceiveFrame(contextWithTimeout);
if (work->FrameType == WebSocketFrameType::TextFrameReceived)
{
auto frame = work->AsTextFrame();
return std::atoi(frame->Text.c_str());
}
if (work->FrameType == WebSocketFrameType::BinaryFrameReceived)
{
auto frame = work->AsBinaryFrame();
throw std::runtime_error("Not implemented");
}
else if (work->FrameType == WebSocketFrameType::PeerClosedReceived)
{
GTEST_LOG_(INFO) << "Remote server closed connection." << std::endl;
throw std::runtime_error("Remote server closed connection.");
}
else
{
throw std::runtime_error("Unknown result type");
}
}
void Reset() { m_socket.SendFrame("reset\n", true); }
void RequestClose() { m_socket.SendFrame("closeme\n", true); }
void Close() { m_socket.Close(); }
void Close(uint16_t closeCode, std::string const& reasonText = {})
{
m_socket.Close(closeCode, reasonText);
}
void ConsumeUntilClosed()
{
while (m_socket.IsOpen())
{
auto work = m_socket.ReceiveFrame();
if (work->FrameType == WebSocketFrameType::PeerClosedReceived)
{
auto peerClose = work->AsPeerCloseFrame();
GTEST_LOG_(INFO) << "Peer closed. Remote Code: " << std::dec << peerClose->RemoteStatusCode
<< " (0x" << std::hex << peerClose->RemoteStatusCode << ")" << std::endl;
if (!peerClose->RemoteCloseReason.empty())
{
GTEST_LOG_(INFO) << " Peer Closed Data: " << peerClose->RemoteCloseReason;
}
GTEST_LOG_(INFO) << std::endl;
return;
}
else if (work->FrameType == WebSocketFrameType::TextFrameReceived)
{
auto frame = work->AsTextFrame();
GTEST_LOG_(INFO) << "Ignoring " << frame->Text << std::endl;
}
}
}
};
class LibWebSocketStatus {
public:
std::string GetLWSStatus()
{
WebSocketOptions options;
options.ServiceName = "websockettest";
// Send 3 protocols to LWS.
options.Protocols.push_back("brownCow");
options.Protocols.push_back("lws-status");
options.Protocols.push_back("flibbityflobbidy");
WebSocket serverSocket(Azure::Core::Url("wss://libwebsockets.org"), options);
serverSocket.Open();
// The server should have chosen the lws-status protocol since it doesn't understand the other
// protocols.
EXPECT_EQ("lws-status", serverSocket.GetNegotiatedProtocol());
std::string returnValue;
std::shared_ptr<WebSocketFrame> lwsStatus;
do
{
lwsStatus = serverSocket.ReceiveFrame();
EXPECT_EQ(WebSocketFrameType::TextFrameReceived, lwsStatus->FrameType);
if (lwsStatus->FrameType == WebSocketFrameType::TextFrameReceived)
{
auto textFrame = lwsStatus->AsTextFrame();
returnValue.insert(returnValue.end(), textFrame->Text.begin(), textFrame->Text.end());
}
} while (!lwsStatus->IsFinalFrame);
serverSocket.Close();
return returnValue;
}
};
TEST_F(WebSocketTests, LibWebSocketOrgLwsStatus)
{
{
LibWebSocketStatus lwsStatus;
auto serverStatus = lwsStatus.GetLWSStatus();
GTEST_LOG_(INFO) << "Server status: " << serverStatus << std::endl;
Azure::Core::Json::_internal::json status;
EXPECT_NO_THROW(status = Azure::Core::Json::_internal::json::parse(serverStatus));
EXPECT_TRUE(status["conns"].is_array());
auto& connections = status["conns"].get_ref<std::vector<Azure::Core::Json::_internal::json>&>();
bool foundOurConnection = false;
// Scan through the list of connections to find a connection from the websockettest.
for (auto& connection : connections)
{
EXPECT_TRUE(connection["ua"].is_string());
auto userAgent = connection["ua"].get<std::string>();
if (userAgent.find("websockettest") != std::string::npos)
{
foundOurConnection = true;
break;
}
}
EXPECT_TRUE(foundOurConnection);
}
}
TEST_F(WebSocketTests, LibWebSocketOrgIncrement)
{
{
LibWebSocketIncrementProtocol incrementProtocol;
incrementProtocol.Open();
// Note that we cannot practically validate the numbers received from the service because
// they may be in flight at the time the "Reset" call is made.
for (auto i = 0; i < 100; i += 1)
{
if (i % 5 == 0)
{
GTEST_LOG_(INFO) << "Reset" << std::endl;
incrementProtocol.Reset();
}
int number = incrementProtocol.GetNextNumber();
GTEST_LOG_(INFO) << "Got next number " << number << std::endl;
}
incrementProtocol.RequestClose();
incrementProtocol.ConsumeUntilClosed();
}
}
#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER)
TEST_F(WebSocketTests, CurlTransportCoverage)
{
{
Azure::Core::Http::WebSockets::CurlWebSocketTransportOptions transportOptions;
transportOptions.HttpKeepAlive = false;
auto transport
= std::make_shared<Azure::Core::Http::WebSockets::CurlWebSocketTransport>(transportOptions);
EXPECT_THROW(transport->NativeCloseSocket(1001, {}, {}), std::runtime_error);
EXPECT_THROW(transport->NativeGetCloseSocketInformation({}), std::runtime_error);
EXPECT_THROW(
transport->NativeSendFrame(WebSocketTransport::NativeWebSocketFrameType::Binary, {}, {}),
std::runtime_error);
EXPECT_THROW(transport->NativeReceiveFrame({}), std::runtime_error);
}
}
#endif