diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index ad83b71d0..207720bcf 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -43,18 +43,22 @@ if(BUILD_TRANSPORT_CURL) src/http/curl/curl_connection_pool_private.hpp src/http/curl/curl_connection_private.hpp src/http/curl/curl_session_private.hpp + src/http/curl/curl_websockets.cpp ) SET(CURL_TRANSPORT_ADAPTER_INC inc/azure/core/http/curl_transport.hpp + inc/azure/core/http/websockets/curl_websockets_transport.hpp ) endif() if(BUILD_TRANSPORT_WINHTTP) SET(WIN_TRANSPORT_ADAPTER_SRC src/http/winhttp/win_http_transport.cpp src/http/winhttp/win_http_request.hpp + src/http/winhttp/win_http_websockets.cpp ) SET(WIN_TRANSPORT_ADAPTER_INC inc/azure/core/http/win_http_transport.hpp + inc/azure/core/http/websockets/win_http_websockets_transport.hpp ) endif() @@ -80,6 +84,8 @@ set( inc/azure/core/http/policies/policy.hpp inc/azure/core/http/raw_response.hpp inc/azure/core/http/transport.hpp + inc/azure/core/http/websockets/websockets.hpp + inc/azure/core/http/websockets/websockets_transport.hpp inc/azure/core/internal/client_options.hpp inc/azure/core/internal/contract.hpp inc/azure/core/internal/credentials/authorization_challenge_parser.hpp @@ -142,6 +148,8 @@ set( src/http/transport_policy.cpp src/http/url.cpp src/http/user_agent.cpp + src/http/websockets/websockets.cpp + src/http/websockets/websockets_impl.cpp src/io/body_stream.cpp src/io/random_access_file_body_stream.cpp src/logger.cpp diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp new file mode 100644 index 000000000..d3b1ecb1a --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via CURL. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/curl_transport.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + struct CurlWebSocketTransportOptions : public Azure::Core::Http::CurlTransportOptions + { + }; + /** + * @brief Concrete implementation of a WebSocket Transport that uses libcurl. + */ + class CurlWebSocketTransport : public CurlTransport, public WebSocketTransport { + public: + /** + * @brief Construct a new CurlWebSocketTransport object. + * + * @param options Optional parameter to override the default options. + */ + CurlWebSocketTransport( + CurlWebSocketTransportOptions const& options = CurlWebSocketTransportOptions()) + : CurlTransport(options) + { + } + + /** + * @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse + * + * @param request an HTTP Request to be send. + * @param context A context to control the request lifetime. + * + * @return unique ptr to an HTTP RawResponse. + */ + virtual std::unique_ptr Send(Request& request, Context const& context) override; + + /** + * @brief Indicates if the transport natively supports websockets or not. + * + * @details For the CURL websocket transport, the transport does NOT support native websockets - + * it is the responsibility of the client of the WebSocketTransport to format WebSocket protocol + * elements. + */ + virtual bool HasBuiltInWebSocketSupport() override { return false; } + + /** + * @brief Closes the WebSocket handle. + * + */ + virtual void Close() override; + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * The first param is the close reason, the second is descriptive text. + */ + virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&) + override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief Retrieve the status of the close socket operation. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + const Azure::Core::Context&) override + { + throw std::runtime_error("Not implemented"); + } + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType, + std::vector const&, + Azure::Core::Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief Receive a frame of data from the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const&) override + { + throw std::runtime_error("Not implemented"); + } + + // Non-Native WebSocket support. + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or until + * there is no more data to get from the socket. + * + * @param buffer Buffer to fill with data. + * @param bufferSize Size of buffer. + * @param context Context to control the request lifetime. + * + * @returns Buffer data received. + * + */ + virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) + override; + + /** + * @brief This method will use libcurl socket to write all the bytes from buffer. + * + * @param buffer Buffer to send. + * @param bufferSize Number of bytes to write. + * @param context Context for the operation. + */ + virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) + override; + + /** + * @brief returns true if this transport supports WebSockets, false otherwise. + */ + bool HasWebSocketSupport() const override { return true; } + + private: + // std::unique_ptr cannot be constructed on an incomplete type (CurlNetworkConnection), but + // std::shared_ptr can be. + std::shared_ptr m_upgradedConnection; + void OnUpgradedConnection( + std::unique_ptr&& upgradedConnection) override; + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp new file mode 100644 index 000000000..a0a55d3bd --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp @@ -0,0 +1,404 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Azure Core APIs implementing the WebSocket protocol [RFC 6455] + * (https://www.rfc-editor.org/rfc/rfc6455.html). + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/internal/client_options.hpp" +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + namespace _detail { + class WebSocketImplementation; + } + namespace _internal { + + enum class WebSocketFrameType : int + { + Unknown, + TextFrameReceived, + BinaryFrameReceived, + PeerClosedReceived, + }; + + enum class WebSocketErrorCode : uint16_t + { + OK = 1000, + EndpointDisappearing = 1001, + ProtocolError = 1002, + UnknownDataType = 1003, + Reserved1 = 1004, + NoStatusCodePresent = 1005, + ConnectionClosedWithoutCloseFrame = 1006, + InvalidMessageData = 1007, + PolicyViolation = 1008, + MessageTooLarge = 1009, + ExtensionNotFound = 1010, + UnexpectedError = 1011, + TlsHandshakeFailure = 1015, + }; + + class WebSocketTextFrame; + class WebSocketBinaryFrame; + class WebSocketPeerCloseFrame; + + namespace _detail { + class WebSocketImplementation; + } + /** @brief Statistics about data sent and received by the WebSocket. + * + * @remarks This class is primarily intended for test collateral and debugging to allow + * a caller to determine information about the status of a WebSocket. + * + * Note: Some of these statistics are not available if the underlying transport supports native + * websockets. + */ + struct WebSocketStatistics + { + /** @brief The number of WebSocket frames sent on this WebSocket. */ + uint32_t FramesSent; + + /** @brief The number of bytes of data sent to the peer on this WebSocket. */ + uint32_t BytesSent; + + /** @brief The number of WebSocket frames received from the peer. */ + uint32_t FramesReceived; + + /** @brief The number of bytes received from the peer. */ + uint32_t BytesReceived; + + /** @brief The number of "Ping" frames received from the peer. */ + uint32_t PingFramesReceived; + + /** @brief The number of "Ping" frames sent to the peer. */ + uint32_t PingFramesSent; + + /** @brief The number of "Pong" frames received from the peer. */ + uint32_t PongFramesReceived; + + /** @brief The number of "Pong" frames sent to the peer. */ + uint32_t PongFramesSent; + + /** @brief The number of "Text" frames received from the peer. */ + uint32_t TextFramesReceived; + + /** @brief The number of "Text" frames sent to the peer. */ + uint32_t TextFramesSent; + + /** @brief The number of "Binary" frames received from the peer. */ + uint32_t BinaryFramesReceived; + + /** @brief The number of "Binary" frames sent to the peer. */ + uint32_t BinaryFramesSent; + + /** @brief The number of "Continuation" frames sent to the peer. */ + uint32_t ContinuationFramesSent; + + /** @brief The number of "Continuation" frames received from the peer. */ + uint32_t ContinuationFramesReceived; + + /** @brief The number of "Close" frames received from the peer. */ + uint32_t CloseFramesReceived; + + /** @brief The number of frames received which were not processed. */ + uint32_t FramesDropped; + + /** @brief The number of frames received which were not returned because they were received + * after the Close() method was called. */ + + uint32_t FramesDroppedByClose; + /** @brief The number of frames dropped because they were over the maximum payload size. */ + + uint32_t FramesDroppedByPayloadSizeLimit; + /** @brief The number of frames dropped because they were out of compliance with the protocol. + */ + uint32_t FramesDroppedByProtocolError; + + /** @brief The number of reads performed on the transport.*/ + uint32_t TransportReads; + + /** @brief The number of bytes read from the transport. */ + uint32_t TransportReadBytes; + }; + + /** @brief A frame of data received from a WebSocket. + */ + class WebSocketFrame { + public: + /** @brief The type of frame received: Text, Binary or Close. */ + WebSocketFrameType FrameType{}; + + /** @brief True if the frame received is a "final" frame */ + bool IsFinalFrame{false}; + + /** @brief Returns the contents of the frame as a Text frame. + * @returns A WebSocketTextFrame containing the contents of the frame. + */ + std::shared_ptr AsTextFrame(); + + /** @brief Returns the contents of the frame as a Binary frame. + * @returns A WebSocketBinaryFrame containing the contents of the frame. + */ + + std::shared_ptr AsBinaryFrame(); + /** @brief Returns the contents of the frame as a Peer Close frame. + * @returns A WebSocketPeerCloseFrame containing the contents of the frame. + */ + std::shared_ptr AsPeerCloseFrame(); + + /** @brief Construct a new instance of a WebSocketFrame.*/ + WebSocketFrame() = default; + + /** @brief Construct a new instance of a WebSocketFrame with a specific frame type. + * @param frameType The type of frame received. + */ + WebSocketFrame(WebSocketFrameType frameType) : FrameType{frameType} {} + + /** @brief Construct a new instance of a WebSocketFrame with a specific frame type and final + * flag. + * @param frameType The type of frame received. + * @param isFinalFrame true if the frame is the final frame. + */ + WebSocketFrame(WebSocketFrameType frameType, bool isFinalFrame) + : FrameType{frameType}, IsFinalFrame{isFinalFrame} + { + } + }; + + /** @brief Contains the contents of a WebSocket Text frame.*/ + class WebSocketTextFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + private: + public: + /** @brief Constructs a new WebSocketTextFrame */ + WebSocketTextFrame() : WebSocketFrame(WebSocketFrameType::TextFrameReceived){}; + + /** @brief Text of the frame received from the remote peer. */ + std::string Text; + + private: + /** @brief Constructs a new WebSocketTextFrame + * @param isFinalFrame True if this is the final frame in a multi-frame message. + * @param body UTF-8 encoded text of the frame data. + * @param size Length in bytes of the frame body. + */ + WebSocketTextFrame(bool isFinalFrame, uint8_t const* body, size_t size) + : WebSocketFrame{WebSocketFrameType::TextFrameReceived, isFinalFrame}, + Text(body, body + size) + { + } + }; + + /** @brief Contains the contents of a WebSocket Binary frame.*/ + class WebSocketBinaryFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + private: + public: + /** @brief Constructs a new WebSocketBinaryFrame */ + WebSocketBinaryFrame() : WebSocketFrame(WebSocketFrameType::BinaryFrameReceived){}; + + /** @brief Binary frame data received from the remote peer. */ + std::vector Data; + + /** @brief Constructs a new WebSocketBinaryFrame + * @param isFinal True if this is the final frame in a multi-frame message. + * @param body binary of the frame data. + * @param size Length in bytes of the frame body. + */ + private: + WebSocketBinaryFrame(bool isFinal, uint8_t const* body, size_t size) + : WebSocketFrame{WebSocketFrameType::BinaryFrameReceived, isFinal}, + Data(body, body + size) + { + } + }; + + /** @brief Contains the contents of a WebSocket Close frame.*/ + class WebSocketPeerCloseFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + public: + /** @brief Constructs a new WebSocketPeerCloseFrame */ + WebSocketPeerCloseFrame() : WebSocketFrame(WebSocketFrameType::PeerClosedReceived){}; + + /** @brief Status code sent from the remote peer. Typically a member of the WebSocketErrorCode + * enumeration */ + uint16_t RemoteStatusCode{}; + + /** @brief Optional text sent from the remote peer. */ + std::string RemoteCloseReason; + + private: + /** @brief Constructs a new WebSocketBinaryFrame + * @param remoteStatusCode Status code sent by the remote peer. + * @param remoteCloseReason Optional reason sent by the remote peer. + */ + WebSocketPeerCloseFrame(uint16_t remoteStatusCode, std::string const& remoteCloseReason) + : WebSocketFrame{WebSocketFrameType::PeerClosedReceived}, + RemoteStatusCode(remoteStatusCode), RemoteCloseReason(remoteCloseReason) + { + } + }; + + struct WebSocketOptions : Azure::Core::_internal::ClientOptions + { + /** + * @brief The set of protocols which are supported by this client + */ + std::vector Protocols = {}; + + /** + * @brief The protocol name of the service client. Used for the User-Agent header + * in the initial WebSocket handshake. + */ + std::string ServiceName; + + /** + * @brief The version of the service client. Used for the User-Agent header in the + * initial WebSocket handshake + */ + std::string ServiceVersion; + + /** + * @brief The period of time between ping operations, default is 60 seconds. + */ + std::chrono::duration PingInterval{std::chrono::seconds{60}}; + + /** + * @brief Construct an instance of a WebSocketOptions type. + * + * @param protocols Supported protocols for this websocket client. + */ + explicit WebSocketOptions(std::vector protocols) + : Azure::Core::_internal::ClientOptions{}, Protocols(protocols) + { + } + WebSocketOptions() = default; + }; + + class WebSocket { + public: + /** @brief Constructs a new instance of a WebSocket with the specified WebSocket options. + * + * @param remoteUrl The URL of the remote WebSocket server. + * @param options The options to use for the WebSocket. + */ + explicit WebSocket( + Azure::Core::Url const& remoteUrl, + WebSocketOptions const& options = WebSocketOptions{}); + + /** @brief Destroys an instance of a WebSocket. + */ + ~WebSocket(); + + /** @brief Opens a WebSocket connection to a remote server. + * + * @param context Context for the operation, used for cancellation and timeout. + */ + void Open(Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Closes a WebSocket connection to the remote server gracefully. + * + * @param context Context for the operation. + */ + void Close(Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Closes a WebSocket connection to the remote server with additional context. + * + * @param closeStatus 16 bit WebSocket error code. + * @param closeReason String describing the reason for closing the socket. + * @param context Context for the operation. + */ + void Close( + uint16_t closeStatus, + std::string const& closeReason = {}, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Sends a String frame to the remote server. + * + * @param textFrame UTF-8 encoded text to send. + * @param isFinalFrame if True, this is the final frame in a multi-frame message. + * @param context Context for the operation. + */ + void SendFrame( + std::string const& textFrame, + bool isFinalFrame = false, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Sends a Binary frame to the remote server. + * + * @param binaryFrame Binary data to send. + * @param isFinalFrame if True, this is the final frame in a multi-frame message. + * @param context Context for the operation. + */ + void SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame = false, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Receive a frame from the remote server. + * + * @param context Context for the operation. + * + * @returns The received WebSocket frame. + * + */ + std::shared_ptr ReceiveFrame( + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief AddHeader - Adds a header to the initial handshake. + * + * @note This API is ignored after the WebSocket is opened. + * + * @param headerName Name of header to add to the initial handshake request. + * @param headerValue Value of header to add. + */ + void AddHeader(std::string const& headerName, std::string const& headerValue); + + /** @brief Determine if the WebSocket is open. + * + * @returns true if the WebSocket is open, false otherwise. + */ + bool IsOpen() const; + + /** @brief Returns "true" if the configured websocket transport + * supports websockets in the transport, or if the websocket implementation + * is providing websocket protocol support. + * + * @returns true if the HTTP transport used for WebSocket support directly supports the + * WebSocket API. + */ + bool HasBuiltInWebSocketSupport() const; + + /** @brief Returns the protocol chosen by the remote server during the initial handshake. + * + * @returns The protocol negotiated between client and server. + */ + std::string const& GetNegotiatedProtocol() const; + + /** @brief Returns statistics about the WebSocket. + * + * @returns The statistics about the WebSocket. + */ + WebSocketStatistics GetStatistics() const; + + private: + std::unique_ptr + m_socketImplementation; + }; + } // namespace _internal +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp new file mode 100644 index 000000000..1afda3c0d --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Utilities to be used by HTTP WebSocket transport implementations. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + /** + * @brief Base class for all WebSocket transport implementations. + */ + class WebSocketTransport { + public: + /** + * @brief Web Socket Frame type, one of Text or Binary. + */ + enum class NativeWebSocketFrameType + { + /** + * @brief Indicates that the frame is a partial UTF-8 encoded text frame - it is NOT the + * complete frame to be sent to the remote node. + */ + TextFragment, + /** + * @brief Indicates that the frame is either the complete UTF-8 encoded text frame to be sent + * to the remote node or the final frame of a multipart message. + */ + Text, + /** + * @brief Indicates that the frame is either the complete binary frame to be sent + * to the remote node or the final frame of a multipart message. + */ + Binary, + /** + * @brief Indicates that the frame is a partial binary frame - it is NOT the + * complete frame to be sent to the remote node. + */ + BinaryFragment, + + /** + * @brief Indicates that the frame is a "close" frame - the remote node + * sent a close frame. + */ + Closed, + }; + + /** @brief Close information returned from a WebSocket transport that has builtin support + * for WebSockets. + */ + struct NativeWebSocketCloseInformation + { + /** + * @brief Close response code. + */ + uint16_t CloseReason; + /** + * @brief Close reason. + */ + std::string CloseReasonDescription; + }; + /** @brief Frame information returned from a WebSocket transport that has builtin support + * for WebSockets. + */ + struct NativeWebSocketReceiveInformation + { + /** + * @brief Type of frame received. + */ + NativeWebSocketFrameType FrameType; + /** + * @brief Data received. + */ + std::vector FrameData; + }; + /** + * @brief Destructs `%WebSocketTransport`. + * + */ + virtual ~WebSocketTransport() {} + + /** + * @brief Indicates whether the transport natively supports WebSockets. + * + * @returns true if the transport has native websocket support, false otherwise. + */ + virtual bool HasBuiltInWebSocketSupport() = 0; + + /** + * @brief Closes the WebSocket. + * + * Does not notify the remote endpoint that the socket is being closed. + * + */ + virtual void Close() = 0; + + /**************/ + /* Native WebSocket support functions*/ + /**************/ + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + */ + virtual void NativeCloseSocket( + uint16_t status, + std::string const& disconnectReason, + Azure::Core::Context const& context) + = 0; + + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + Azure::Core::Context const& context) + = 0; + + /** + * @brief Send a frame of data to the remote node. + * + * @param frameType Frame type sent to the server, Text or Binary. + * @param frameData Frame data to be sent to the server. + * @param context Context for the operation. + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType frameType, + std::vector const& frameData, + Azure::Core::Context const& context) + = 0; + + /** + * @brief Receive a frame from the remote WebSocket server. + * + * @param context Context for the operation. + * + * @returns a tuple containing the Frame data received from the remote server and the type of + * data returned from the remote endpoint + */ + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const& context) + = 0; + + /**************/ + /* Non Native WebSocket support functions */ + /**************/ + + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or until + * there is no more data to get from the socket. + * + */ + virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) = 0; + + /** + * @brief This method will use the raw socket to write all the bytes from buffer. + * + */ + virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) = 0; + + protected: + /** + * @brief Constructs a default instance of `%WebSocketTransport`. + * + */ + WebSocketTransport() = default; + + /** + * @brief Constructs `%HttpTransport` by copying another instance of `%HttpTransport`. + * + * @param other An instance to copy. + */ + WebSocketTransport(const WebSocketTransport& other) = default; + + /** + * @brief Constructs a WebSocketTransport from another WebSocketTransport. + * + * @param other An instance to move in. + */ + WebSocketTransport(WebSocketTransport&& other) = default; + + /** + * @brief Assigns one WebSocketTransport to another. + * + * @param other An instance to assign. + * + * @return A reference to this instance. + */ + WebSocketTransport& operator=(const WebSocketTransport& other) = default; + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp new file mode 100644 index 000000000..8fdf5b533 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via WInHTTP. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include "azure/core/http/win_http_transport.hpp" +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + /** + * @brief Concrete implementation of a WebSocket Transport that uses WinHTTP. + */ + class WinHttpWebSocketTransport : public WebSocketTransport, public WinHttpTransport { + + Azure::Core::Http::_detail::unique_HINTERNET m_socketHandle; + std::mutex m_sendMutex; + std::mutex m_receiveMutex; + + // Called by the + void OnUpgradedConnection( + Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) override; + + public: + /** + * @brief Construct a new WinHTTP WebSocket Transport. + * + * @param options Optional parameter to override the default options. + */ + WinHttpWebSocketTransport(WinHttpTransportOptions const& options = WinHttpTransportOptions()) + : WinHttpTransport(options) + { + } + + /** + * @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse + * + * @param request an HTTP Request to be send. + * @param context A context to control the request lifetime. + * + * @return unique ptr to an HTTP RawResponse. + */ + virtual std::unique_ptr Send(Request& request, Context const& context) override; + + /** + * @brief Indicates if the transports natively websockets or not. + * + * @details For the WinHTTP websocket transport, the WinHTTP API supports websockets. + */ + virtual bool HasBuiltInWebSocketSupport() override { return true; } + + /** + * @brief Close the underlying WebSocket handle. + * + */ + virtual void Close() override; + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + * + */ + virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&) + override; + + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + Azure::Core::Context const& context) override; + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native + * websockets. + * + * @brief frameType Frame type sent to the server, Text or Binary. + * @brief frameData Frame data to be sent to the server. + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType, + std::vector const&, + Azure::Core::Context const&) override; + + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const&) override; + + // Non-Native WebSocket support. + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or + * until there is no more data to get from the socket. + * + * @details Not implemented for WinHTTP websockets because WinHTTP implements websockets + * natively. + */ + virtual size_t ReadFromSocket(uint8_t*, size_t, Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief This method will use sockets to write all the bytes from buffer. + * + * @details Not implemented for WinHTTP websockets because WinHTTP implements websockets + * natively. + * + */ + virtual int SendBuffer(uint8_t const*, size_t, Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + bool HasWebSocketSupport() const override { return true; } + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/src/http/curl/curl_websockets.cpp b/sdk/core/azure-core/src/http/curl/curl_websockets.cpp new file mode 100644 index 000000000..49e182cae --- /dev/null +++ b/sdk/core/azure-core/src/http/curl/curl_websockets.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/http.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/platform.hpp" + +// Private include +#include "curl_connection_private.hpp" + +#if defined(AZ_PLATFORM_POSIX) +#include // for poll() +#include // for socket shutdown +#elif defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include // for WSAPoll(); +#endif + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + void CurlWebSocketTransport::Close() { m_upgradedConnection->Shutdown(); } + + // Send an HTTP request to the remote server. + std::unique_ptr CurlWebSocketTransport::Send( + Request& request, + Context const& context) + { + // CURL doesn't understand the ws and wss protocols, so change the URL to be http based. + std::string requestScheme(request.GetUrl().GetScheme()); + if (requestScheme == "wss" || requestScheme == "ws") + { + if (requestScheme == "wss") + { + request.GetUrl().SetScheme("https"); + } + else + { + request.GetUrl().SetScheme("http"); + } + } + return CurlTransport::Send(request, context); + } + + size_t CurlWebSocketTransport::ReadFromSocket( + uint8_t* buffer, + size_t bufferSize, + Context const& context) + { + return m_upgradedConnection->ReadFromSocket(buffer, bufferSize, context); + } + + /** + * @brief This method will use libcurl socket to write all the bytes from buffer. + * + */ + int CurlWebSocketTransport::SendBuffer( + uint8_t const* buffer, + size_t bufferSize, + Context const& context) + { + return m_upgradedConnection->SendBuffer(buffer, bufferSize, context); + } + + void CurlWebSocketTransport::OnUpgradedConnection( + std::unique_ptr&& upgradedConnection) + { + // Note that m_upgradedConnection is a std::shared_ptr. We define it as a std::shared_ptr + // because a std::shared_ptr can be declared on an incomplete type, while a std::unique_ptr + // cannot. + m_upgradedConnection = std::move(upgradedConnection); + } + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/src/http/websockets/websockets.cpp b/sdk/core/azure-core/src/http/websockets/websockets.cpp new file mode 100644 index 000000000..65102f31a --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/context.hpp" +#include "websockets_impl.hpp" + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _internal { + + WebSocket::WebSocket(Azure::Core::Url const& remoteUrl, WebSocketOptions const& options) + : m_socketImplementation( + std::make_unique( + remoteUrl, + options)) + + { + } + WebSocket::~WebSocket() {} + + void WebSocket::Open(Azure::Core::Context const& context) + { + m_socketImplementation->Open(context); + } + void WebSocket::Close(Azure::Core::Context const& context) + { + m_socketImplementation->Close( + static_cast(WebSocketErrorCode::EndpointDisappearing), {}, context); + } + void WebSocket::Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context) + { + m_socketImplementation->Close(closeStatus, closeReason, context); + } + + void WebSocket::SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + m_socketImplementation->SendFrame(textFrame, isFinalFrame, context); + } + + void WebSocket::SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + m_socketImplementation->SendFrame(binaryFrame, isFinalFrame, context); + } + + WebSocketStatistics WebSocket::GetStatistics() const + { + return m_socketImplementation->GetStatistics(); + } + + bool WebSocket::HasBuiltInWebSocketSupport() const + { + return m_socketImplementation->HasBuiltInWebSocketSupport(); + } + + std::shared_ptr WebSocket::ReceiveFrame(Azure::Core::Context const& context) + { + return m_socketImplementation->ReceiveFrame(context); + } + + void WebSocket::AddHeader(std::string const& headerName, std::string const& headerValue) + { + m_socketImplementation->AddHeader(headerName, headerValue); + } + std::string const& WebSocket::GetNegotiatedProtocol() const + { + return m_socketImplementation->GetNegotiatedProtocol(); + } + + bool WebSocket::IsOpen() const { return m_socketImplementation->IsOpen(); } + + std::shared_ptr WebSocketFrame::AsTextFrame() + { + if (FrameType != WebSocketFrameType::TextFrameReceived) + { + throw std::logic_error("Cannot cast to TextFrameReceived."); + } + return static_cast(this)->shared_from_this(); + } + + std::shared_ptr WebSocketFrame::AsBinaryFrame() + { + if (FrameType != WebSocketFrameType::BinaryFrameReceived) + { + throw std::logic_error("Cannot cast to BinaryFrameReceived."); + } + return static_cast(this)->shared_from_this(); + } + + std::shared_ptr WebSocketFrame::AsPeerCloseFrame() + { + if (FrameType != WebSocketFrameType::PeerClosedReceived) + { + throw std::logic_error("Cannot cast to PeerClose."); + } + return static_cast(this)->shared_from_this(); + } + +}}}}} // namespace Azure::Core::Http::WebSockets::_internal diff --git a/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp b/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp new file mode 100644 index 000000000..f7d43b0fc --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp @@ -0,0 +1,876 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT +#include "websockets_impl.hpp" +#include "azure/core/base64.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/internal/cryptography/sha_hash.hpp" +// SUPPORT_NATIVE_TRANSPORT indicates if WinHTTP should be compiled with native transport support +// or not. +// Note that this is primarily required to improve the code coverage numbers in the CI pipeline. +#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER) +#include "azure/core/http/websockets/win_http_websockets_transport.hpp" +#define SUPPORT_NATIVE_TRANSPORT 1 +#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#define SUPPORT_NATIVE_TRANSPORT 0 +#endif +#include "azure/core/internal/diagnostics/log.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail { + using namespace Azure::Core::Http::WebSockets::_internal; + using namespace Azure::Core::Diagnostics::_internal; + using namespace Azure::Core::Diagnostics; + using namespace std::chrono_literals; + + namespace { + std::string HexEncode(std::vector const& data, size_t length) + { + std::stringstream ss; + for (size_t i = 0; i < std::min(data.size(), length); i++) + { + ss << std::hex << std::setfill('0') << std::setw(2) << static_cast(data[i]); + } + return ss.str(); + } + } // namespace + + WebSocketImplementation::WebSocketImplementation( + Azure::Core::Url const& remoteUrl, + WebSocketOptions const& options) + : m_remoteUrl(remoteUrl), m_options(options), m_pingThread(this, m_options.PingInterval) + { + } + + void WebSocketImplementation::Open(Azure::Core::Context const& context) + { + if (m_state != SocketState::Invalid && m_state != SocketState::Closed) + { + throw std::runtime_error( + "Socket in unexpected state: " + std::to_string(static_cast(m_state))); + } + m_state = SocketState::Opening; + +#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER) + WinHttpTransportOptions transportOptions; + auto winHttpTransport + = std::make_shared( + transportOptions); + m_transport = std::static_pointer_cast(winHttpTransport); + m_options.Transport.Transport = std::static_pointer_cast(winHttpTransport); +#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) + CurlWebSocketTransportOptions transportOptions; + transportOptions.HttpKeepAlive = false; + auto curlWebSockets + = std::make_shared(transportOptions); + + m_transport = std::static_pointer_cast(curlWebSockets); + m_options.Transport.Transport = std::static_pointer_cast(curlWebSockets); +#endif + + std::vector> perCallPolicies{}; + std::vector> perRetryPolicies{}; + // If the caller has told us a service name, add the telemetry policy to the pipeline to add + // Create pipeline for WebSocket handshake + std::string serviceName = m_options.ServiceName.empty() ? "azure.core.websockets" : m_options.ServiceName; + std::string serviceVersion = m_options.ServiceVersion.empty() ? "1.0.0" : m_options.ServiceVersion; + + Azure::Core::Http::_internal::HttpPipeline openPipeline( + m_options, serviceName, serviceVersion, std::move(perRetryPolicies), std::move(perCallPolicies)); + + Azure::Core::Http::Request openSocketRequest( + Azure::Core::Http::HttpMethod::Get, m_remoteUrl, false); + + // Generate the random request key. Only used when the transport doesn't support websockets + // natively. + auto randomKey = GenerateRandomKey(); + auto encodedKey = Azure::Core::Convert::Base64Encode(randomKey); + if (!m_transport->HasBuiltInWebSocketSupport()) + { + // If the transport doesn't support WebSockets natively, set the standardized WebSocket + // upgrade headers. + openSocketRequest.SetHeader("Upgrade", "websocket"); + openSocketRequest.SetHeader("Connection", "upgrade"); + openSocketRequest.SetHeader("Sec-WebSocket-Version", "13"); + openSocketRequest.SetHeader("Sec-WebSocket-Key", encodedKey); + } + if (!m_options.Protocols.empty()) + { + + std::string protocols; + for (auto const& protocol : m_options.Protocols) + { + protocols += protocol; + protocols += ", "; + } + protocols = protocols.substr(0, protocols.size() - 2); + openSocketRequest.SetHeader("Sec-WebSocket-Protocol", protocols); + } + for (auto const& additionalHeader : m_headers) + { + openSocketRequest.SetHeader(additionalHeader.first, additionalHeader.second); + } + std::string remoteOrigin; + remoteOrigin = m_remoteUrl.GetScheme(); + remoteOrigin += "://"; + remoteOrigin += m_remoteUrl.GetHost(); + openSocketRequest.SetHeader("Origin", remoteOrigin); + + // Send the connect request to the WebSocket server. + auto response = openPipeline.Send(openSocketRequest, context); + + // Ensure that the server thinks we're switching protocols. If it doesn't, + // fail immediately. + if (response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::SwitchingProtocols) + { + throw Azure::Core::Http::TransportException("Unexpected handshake response"); + } + + // Prove that the server received this socket request. + auto& responseHeaders = response->GetHeaders(); + if (!m_transport->HasBuiltInWebSocketSupport()) + { + auto socketAccept(responseHeaders.find("Sec-WebSocket-Accept")); + if (socketAccept == responseHeaders.end()) + { + throw Azure::Core::Http::TransportException("Missing Sec-WebSocket-Accept header"); + } + // Verify that the WebSocket server received *this* open request. + else + { + VerifySocketAccept(encodedKey, socketAccept->second); + } + m_initialBodyStream = response->ExtractBodyStream(); + m_pingThread.Start(m_transport); + } + + // Remember the protocol that the client chose. + auto chosenProtocol = responseHeaders.find("Sec-WebSocket-Protocol"); + if (chosenProtocol != responseHeaders.end()) + { + m_chosenProtocol = chosenProtocol->second; + } + + m_state = SocketState::Open; + } + bool WebSocketImplementation::HasBuiltInWebSocketSupport() + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + return m_transport->HasBuiltInWebSocketSupport(); + } + + std::string const& WebSocketImplementation::GetNegotiatedProtocol() + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + return m_chosenProtocol; + } + + void WebSocketImplementation::AddHeader(std::string const& header, std::string const& headerValue) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Closed && m_state != SocketState::Invalid) + { + throw std::runtime_error("AddHeader can only be called on closed sockets."); + } + m_headers.emplace(std::make_pair(header, headerValue)); + } + + void WebSocketImplementation::Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context) + { + std::unique_lock lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + // If we're closing an already closed socket, we're done. + if (m_state == SocketState::Closed) + { + return; + } + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + m_state = SocketState::Closing; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeCloseSocket(closeStatus, closeReason.c_str(), context); + } + else +#endif + { + // Send a going away message to the server. + std::vector closePayload; + closePayload.push_back(closeStatus >> 8); + closePayload.push_back(closeStatus & 0xff); + closePayload.insert(closePayload.end(), closeReason.begin(), closeReason.end()); + std::vector closeFrame = EncodeFrame(SocketOpcode::Close, true, closePayload); + SendTransportBuffer(closeFrame, context); + + // Unlock the state mutex before waiting for the close response to be received. + lock.unlock(); + // Drain the incoming series of frames from the server. + // Note that there might be in-flight frames that were sent from the other end of the + // WebSocket that we don't care about any more (since we're closing the WebSocket). So + // drain those frames. + auto closeResponse = ReceiveTransportFrame(context); + while (closeResponse && closeResponse->Opcode != SocketOpcode::Close) + { + m_receiveStatistics.FramesDroppedByClose++; + Log::Write( + Logger::Level::Warning, + "Received unexpected frame during close. Opcode: " + + std::to_string(static_cast(closeResponse->Opcode))); + closeResponse = ReceiveTransportFrame(context); + } + + // Re-acquire the state lock once we've received the close response. + lock.lock(); + m_stateOwner = std::this_thread::get_id(); + } + // Close the socket - after this point, the m_transport is invalid. + m_pingThread.Shutdown(); + m_transport->Close(); + m_state = SocketState::Closed; + } + + void WebSocketImplementation::SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + std::vector utf8text(textFrame.begin(), textFrame.end()); + m_receiveStatistics.TextFramesSent++; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeSendFrame( + (isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Text + : WebSocketTransport::NativeWebSocketFrameType::TextFragment), + utf8text, + context); + } + else +#endif + { + std::vector sendFrame = EncodeFrame(SocketOpcode::TextFrame, isFinalFrame, utf8text); + SendTransportBuffer(sendFrame, context); + } + } + + void WebSocketImplementation::SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + m_receiveStatistics.BinaryFramesSent++; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeSendFrame( + (isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Binary + : WebSocketTransport::NativeWebSocketFrameType::BinaryFragment), + binaryFrame, + context); + } + else +#endif + { + std::vector sendFrame + = EncodeFrame(SocketOpcode::BinaryFrame, isFinalFrame, binaryFrame); + + SendTransportBuffer(sendFrame, context); + } + } + + std::shared_ptr WebSocketImplementation::ReceiveFrame( + Azure::Core::Context const& context) + { + std::unique_lock lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + if (m_state != SocketState::Open && m_state != SocketState::Closing) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + + // Unlock the state lock to allow other threads to run. If we don't, we might end up in in a + // situation where the server won't respond to the this client because all the client threads + // are blocked on the state lock. + lock.unlock(); + + std::shared_ptr frame; + // Loop until we receive an returnable incoming frame. + // If the incoming frame is returnable, we return the value from the frame. + while (true) + { + frame = ReceiveTransportFrame(context); + if (frame) + { + switch (frame->Opcode) + { + // When we receive a "ping" frame, we want to send a Pong frame back to the server. + case SocketOpcode::Ping: + Log::Write( + Logger::Level::Verbose, "Received Ping frame: " + HexEncode(frame->Payload, 16)); + SendPong(frame->Payload, context); + break; + + // We want to ignore all incoming "Pong" frames. + case SocketOpcode::Pong: + Log::Write( + Logger::Level::Verbose, "Received Pong frame: " + HexEncode(frame->Payload, 16)); + break; + + case SocketOpcode::BinaryFrame: + m_currentMessageType = SocketMessageType::Binary; + return std::shared_ptr(new WebSocketBinaryFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + + case SocketOpcode::TextFrame: + m_currentMessageType = SocketMessageType::Text; + return std::shared_ptr(new WebSocketTextFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + + case SocketOpcode::Close: { + if (frame->Payload.size() < 2) + { + throw std::runtime_error("Close response buffer is too short."); + } + // Encode the payload for close according to RFC 6455 + // section 5.5.1. The first two bytes of the payload contain the status code. + // The remainder of the payload is a UTF-8 encoded string. + uint16_t errorCode = 0; + errorCode |= (frame->Payload[0] << 8) & 0xff00; + errorCode |= (frame->Payload[1] & 0x00ff); + + // We received a close frame, mark the socket as closed. Make sure we + // reacquire the state lock before setting the state to closed. + lock.lock(); + m_stateOwner = std::this_thread::get_id(); + m_state = SocketState::Closed; + + return std::shared_ptr(new WebSocketPeerCloseFrame( + errorCode, std::string(frame->Payload.begin() + 2, frame->Payload.end()))); + } + + // Continuation frames need to be treated somewhat specially. + // We depend on the fact that the protocol requires that a Continuation frame + // only be sent if it is part of a multi-frame message whose previous frame was a Text + // or Binary frame. + case SocketOpcode::Continuation: + if (m_currentMessageType == SocketMessageType::Text) + { + if (frame->IsFinalFrame) + { + m_currentMessageType = SocketMessageType::Unknown; + } + return std::shared_ptr(new WebSocketTextFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + } + else if (m_currentMessageType == SocketMessageType::Binary) + { + if (frame->IsFinalFrame) + { + m_currentMessageType = SocketMessageType::Unknown; + } + return std::shared_ptr(new WebSocketBinaryFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + } + else + { + m_receiveStatistics.FramesDroppedByProtocolError++; + throw std::runtime_error("Unknown message type and received continuation opcode"); + } + default: + throw std::runtime_error("Unknown frame type received."); + } + } + else + { + if (m_state != SocketState::Closed && m_state != SocketState::Closing) + { + throw std::runtime_error("Transport is at EOF, no frame to receive."); + } + // The socket was closed, most likely locally, so fake a close frame response. + return std::shared_ptr(new WebSocketPeerCloseFrame()); + } + + context.ThrowIfCancelled(); + } + } + + std::shared_ptr + WebSocketImplementation::ReceiveTransportFrame(Azure::Core::Context const& context) + { +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + auto payload = m_transport->NativeReceiveFrame(context); + m_receiveStatistics.FramesReceived++; + switch (payload.FrameType) + { + case WebSocketTransport::NativeWebSocketFrameType::Binary: + m_receiveStatistics.BinaryFramesReceived++; + return std::make_shared( + SocketOpcode::BinaryFrame, true, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::BinaryFragment: + m_receiveStatistics.BinaryFramesReceived++; + return std::make_shared( + SocketOpcode::BinaryFrame, false, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::Text: + m_receiveStatistics.TextFramesReceived++; + return std::make_shared( + SocketOpcode::TextFrame, true, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::TextFragment: + m_receiveStatistics.TextFramesReceived++; + return std::make_shared( + SocketOpcode::TextFrame, false, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::Closed: { + m_receiveStatistics.CloseFramesReceived++; + auto closeResult = m_transport->NativeGetCloseSocketInformation(context); + std::vector closePayload; + closePayload.push_back(closeResult.CloseReason >> 8); + closePayload.push_back(closeResult.CloseReason & 0xff); + closePayload.insert( + closePayload.end(), + closeResult.CloseReasonDescription.begin(), + closeResult.CloseReasonDescription.end()); + return std::make_shared(SocketOpcode::Close, true, closePayload); + } + default: + throw std::runtime_error("Unexpected frame type received."); + } + } + else +#endif + { + std::shared_ptr frame = DecodeFrame(context); + if (frame) + { + + // Handle statistics for the incoming frame. + m_receiveStatistics.FramesReceived++; + switch (frame->Opcode) + { + case SocketOpcode::Ping: { + m_receiveStatistics.PingFramesReceived++; + break; + } + case SocketOpcode::Pong: { + m_receiveStatistics.PongFramesReceived++; + break; + } + case SocketOpcode::TextFrame: { + m_receiveStatistics.TextFramesReceived++; + break; + } + case SocketOpcode::BinaryFrame: { + m_receiveStatistics.BinaryFramesReceived++; + break; + } + case SocketOpcode::Close: { + m_receiveStatistics.CloseFramesReceived++; + break; + } + case SocketOpcode::Continuation: { + m_receiveStatistics.ContinuationFramesReceived++; + break; + } + default: { + m_receiveStatistics.UnknownFramesReceived++; + break; + } + } + } + else + { + m_receiveStatistics.FramesDropped++; + } + return frame; + } + } + + WebSocketStatistics WebSocketImplementation::GetStatistics() const + { + WebSocketStatistics returnValue{}; + returnValue.FramesSent = m_receiveStatistics.FramesSent.load(); + returnValue.FramesReceived = m_receiveStatistics.FramesReceived.load(); + returnValue.BinaryFramesReceived = m_receiveStatistics.BinaryFramesReceived.load(); + returnValue.TextFramesReceived = m_receiveStatistics.TextFramesReceived.load(); + returnValue.BinaryFramesSent = m_receiveStatistics.BinaryFramesSent.load(); + returnValue.TextFramesSent = m_receiveStatistics.TextFramesSent.load(); + returnValue.PingFramesReceived = m_receiveStatistics.PingFramesReceived.load(); + returnValue.PongFramesReceived = m_receiveStatistics.PongFramesReceived.load(); + returnValue.PingFramesSent = m_receiveStatistics.PingFramesSent.load(); + returnValue.PongFramesSent = m_receiveStatistics.PongFramesSent.load(); + + returnValue.BytesSent = m_receiveStatistics.BytesSent.load(); + returnValue.BytesReceived = m_receiveStatistics.BytesReceived.load(); + returnValue.FramesDropped = m_receiveStatistics.FramesDropped.load(); + returnValue.FramesDroppedByClose = m_receiveStatistics.FramesDroppedByClose.load(); + returnValue.FramesDroppedByPayloadSizeLimit + = m_receiveStatistics.FramesDroppedByPayloadSizeLimit.load(); + returnValue.FramesDroppedByProtocolError + = m_receiveStatistics.FramesDroppedByProtocolError.load(); + returnValue.TransportReadBytes = m_receiveStatistics.TransportReadBytes.load(); + returnValue.TransportReads = m_receiveStatistics.TransportReads.load(); + return returnValue; + } + + std::vector WebSocketImplementation::EncodeFrame( + SocketOpcode opcode, + bool isFinal, + std::vector const& payload) + { + std::vector encodedFrame; + // Add opcode+fin. + encodedFrame.push_back(static_cast(opcode) | (isFinal ? 0x80 : 0)); + uint8_t maskAndLength = 0; + maskAndLength |= 0x80; + + // Payloads smaller than 125 bytes are encoded directly in the maskAndLength field. + uint64_t payloadSize = static_cast(payload.size()); + if (payloadSize <= 125) + { + maskAndLength |= static_cast(payload.size()); + } + else if (payloadSize <= 65535) + { + // Payloads greater than 125 whose size can fit in a 16 bit integer bytes + // are encoded as a 16 bit unsigned integer in network byte order. + maskAndLength |= 126; + } + else + { + // Payloads greater than 65536 have their length are encoded as a 64 bit unsigned integer + // in network byte order. + maskAndLength |= 127; + } + encodedFrame.push_back(maskAndLength); + // Encode a 16 bit length. + if (payloadSize > 125 && payloadSize <= 65535) + { + encodedFrame.push_back(static_cast(payload.size()) >> 8); + encodedFrame.push_back(static_cast(payload.size()) & 0xff); + } + // Encode a 64 bit length. + else if (payloadSize >= 65536) + { + + encodedFrame.push_back((payloadSize >> 56) & 0xff); + encodedFrame.push_back((payloadSize >> 48) & 0xff); + encodedFrame.push_back((payloadSize >> 40) & 0xff); + encodedFrame.push_back((payloadSize >> 32) & 0xff); + encodedFrame.push_back((payloadSize >> 24) & 0xff); + encodedFrame.push_back((payloadSize >> 16) & 0xff); + encodedFrame.push_back((payloadSize >> 8) & 0xff); + encodedFrame.push_back(payloadSize & 0xff); + } + // Calculate the masking key. This MUST be 4 bytes of high entropy random numbers used to + // mask the input data. + { + // Start by generating the mask - 4 bytes of random data. + std::vector mask = GenerateRandomBytes(4); + + // Append the mask to the payload. + encodedFrame.insert(encodedFrame.end(), mask.begin(), mask.end()); + + // And mask the payload before transmitting it. + size_t index = 0; + for (auto ch : payload) + { + encodedFrame.push_back(ch ^ mask[index % 4]); + index += 1; + } + } + + return encodedFrame; + } + std::shared_ptr + WebSocketImplementation::DecodeFrame(Azure::Core::Context const& context) + { + // Ensure single threaded access to receive this frame. + std::unique_lock lock(m_transportMutex); + if (IsTransportEof()) + { + throw std::runtime_error("Frame buffer is too small."); + } + uint8_t payloadByte = ReadTransportByte(context); + // If the transport is at EOF, then there is no payload data, so just return null. + if (IsTransportEof()) + { + return nullptr; + } + SocketOpcode opcode = static_cast(payloadByte & 0x7f); + bool isFinal = (payloadByte & 0x80) != 0; + payloadByte = ReadTransportByte(context); + if (IsTransportEof()) + { + return nullptr; + } + if (payloadByte & 0x80) + { + throw std::runtime_error("Server sent a frame with a reserved bit set."); + } + int64_t payloadLength = payloadByte & 0x7f; + if (payloadLength <= 125) + { + payloadByte += 1; + } + else if (payloadLength == 126) + { + payloadLength = ReadTransportShort(context); + } + else if (payloadLength == 127) + { + payloadLength = ReadTransportInt64(context); + } + else + { + throw std::logic_error("Unexpected payload length."); + } + if (IsTransportEof()) + { + return nullptr; + } + + std::vector payload(ReadTransportBytes(static_cast(payloadLength), context)); + if (IsTransportEof()) + { + return nullptr; + } + return std::make_shared(opcode, isFinal, payload); + } + + uint8_t WebSocketImplementation::ReadTransportByte(Azure::Core::Context const& context) + { + if (m_bufferPos >= m_bufferLen) + { + // Start by reading data from our initial body stream. + m_bufferLen = m_initialBodyStream->ReadToCount(m_buffer, m_bufferSize, context); + if (m_bufferLen == 0) + { + // If we run out of the initial stream, we need to read from the transport. + m_bufferLen = m_transport->ReadFromSocket(m_buffer, m_bufferSize, context); + m_receiveStatistics.TransportReads++; + m_receiveStatistics.TransportReadBytes += static_cast(m_bufferLen); + } + else + { + Azure::Core::Diagnostics::_internal::Log::Write( + Azure::Core::Diagnostics::Logger::Level::Informational, + "Read data from initial stream"); + } + m_bufferPos = 0; + if (m_bufferLen == 0) + { + m_eof = true; + return 0; + } + } + + m_receiveStatistics.BytesReceived++; + return m_buffer[m_bufferPos++]; + } + uint16_t WebSocketImplementation::ReadTransportShort(Azure::Core::Context const& context) + { + uint16_t result = ReadTransportByte(context); + result <<= 8; + result |= ReadTransportByte(context); + return result; + } + uint64_t WebSocketImplementation::ReadTransportInt64(Azure::Core::Context const& context) + { + uint64_t result = 0; + + result |= (static_cast(ReadTransportByte(context)) << 56 & 0xff00000000000000); + result |= (static_cast(ReadTransportByte(context)) << 48 & 0x00ff000000000000); + result |= (static_cast(ReadTransportByte(context)) << 40 & 0x0000ff0000000000); + result |= (static_cast(ReadTransportByte(context)) << 32 & 0x000000ff00000000); + result |= (static_cast(ReadTransportByte(context)) << 24 & 0x00000000ff000000); + result |= (static_cast(ReadTransportByte(context)) << 16 & 0x0000000000ff0000); + result |= (static_cast(ReadTransportByte(context)) << 8 & 0x000000000000ff00); + result |= static_cast(ReadTransportByte(context)); + return result; + } + std::vector WebSocketImplementation::ReadTransportBytes( + size_t readLength, + Azure::Core::Context const& context) + { + std::vector result; + size_t index = 0; + while (index < readLength) + { + uint8_t byte = ReadTransportByte(context); + result.push_back(byte); + index += 1; + } + return result; + } + + void WebSocketImplementation::SendTransportBuffer( + std::vector const& sendFrame, + Azure::Core::Context const& context) + { + std::unique_lock transportLock(m_transportMutex); + m_receiveStatistics.BytesSent += static_cast(sendFrame.size()); + m_receiveStatistics.FramesSent += 1; + m_transport->SendBuffer(sendFrame.data(), sendFrame.size(), context); + } + + // Verify the Sec-WebSocket-Accept header as defined in RFC 6455 Section 1.3, which defines + // the opening handshake used for establishing the WebSocket connection. + std::string acceptHeaderGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + void WebSocketImplementation::VerifySocketAccept( + std::string const& encodedKey, + std::string const& acceptHeader) + { + std::string concatenatedKey(encodedKey); + concatenatedKey += acceptHeaderGuid; + Azure::Core::Cryptography::_internal::Sha1Hash sha1hash; + + sha1hash.Append( + reinterpret_cast(concatenatedKey.data()), concatenatedKey.size()); + auto keyHash = sha1hash.Final(); + std::string encodedHash = Azure::Core::Convert::Base64Encode(keyHash); + if (encodedHash != acceptHeader) + { + throw std::runtime_error( + "Hash returned by WebSocket server does not match expected hash. Aborting"); + } + } + + WebSocketImplementation::PingThread::PingThread( + WebSocketImplementation* socketImplementation, + std::chrono::duration pingInterval) + : m_webSocketImplementation(socketImplementation), m_pingInterval(pingInterval) + { + } + void WebSocketImplementation::PingThread::Start(std::shared_ptr transport) + { + m_stop = false; + // Spin up a thread to receive data from the transport. + if (!transport->HasBuiltInWebSocketSupport()) + { + std::unique_lock lock(m_pingThreadStarted); + m_pingThread = std::thread{&PingThread::PingThreadLoop, this}; + m_pingThreadReady.wait(lock); + } + } + + WebSocketImplementation::PingThread::~PingThread() + { + // Ensure that the receive thread is stopped. + Shutdown(); + } + void WebSocketImplementation::PingThread::Shutdown() + { + if (m_pingThread.joinable()) + { + std::unique_lock lock(m_stopMutex); + m_stop = true; + lock.unlock(); + m_pingThreadStopped.notify_all(); + + m_pingThread.join(); + } + } + + void WebSocketImplementation::PingThread::PingThreadLoop() + { + Log::Write(Logger::Level::Verbose, "Start Ping Thread Loop."); + { + std::unique_lock lock(m_pingThreadStarted); + m_pingThreadReady.notify_all(); + } + while (true) + { + std::unique_lock lock(m_stopMutex); + if (this->m_pingThreadStopped.wait_for(lock, m_pingInterval) == std::cv_status::timeout) + { + Log::Write(Logger::Level::Verbose, "Send Ping to peer."); + + // The receiveContext timed out, this means we timed out our "ping" timeout. + // Send a "Ping" request to the remote node. + auto pingData = GenerateRandomBytes(4); + SendPing(pingData, Azure::Core::Context{}); + } + if (m_stop) + { + Log::Write(Logger::Level::Verbose, "Exiting ping thread"); + return; + } + } + } + + bool WebSocketImplementation::PingThread::SendPing( + std::vector const& pingData, + Azure::Core::Context const& context) + { + std::vector pingFrame = EncodeFrame(SocketOpcode::Ping, true, pingData); + m_webSocketImplementation->m_receiveStatistics.PingFramesSent++; + m_webSocketImplementation->SendTransportBuffer(pingFrame, context); + return true; + } + + void WebSocketImplementation::SendPong( + std::vector const& pongData, + Azure::Core::Context const& context) + { + std::vector pongFrame = EncodeFrame(SocketOpcode::Pong, true, pongData); + + m_receiveStatistics.PongFramesSent++; + SendTransportBuffer(pongFrame, context); + } + + // Generator for random bytes. Used in WebSocketImplementation and tests. + std::vector GenerateRandomBytes(size_t vectorSize) + { + std::random_device randomEngine; + + std::vector rv(vectorSize); + std::generate(begin(rv), end(rv), [&randomEngine]() mutable { + return static_cast(randomEngine() % UINT8_MAX); + }); + return rv; + } +}}}}} // namespace Azure::Core::Http::WebSockets::_detail diff --git a/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp b/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp new file mode 100644 index 000000000..48f17484a --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp @@ -0,0 +1,373 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/internal/http/pipeline.hpp" +#include +#include +#include +#include +#include +#include + +// Implementation of WebSocket protocol. +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail { + + // Generator for random bytes. Used in WebSocketImplementation and tests. + std::vector GenerateRandomBytes(size_t vectorSize); + + class WebSocketImplementation { + enum class SocketState + { + Invalid, + Closed, + Opening, + Open, + Closing, + }; + + public: + WebSocketImplementation( + Azure::Core::Url const& remoteUrl, + _internal::WebSocketOptions const& options); + + void Open(Azure::Core::Context const& context); + void Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context); + void SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context); + void SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context); + + std::shared_ptr<_internal::WebSocketFrame> ReceiveFrame(Azure::Core::Context const& context); + + void AddHeader(std::string const& headerName, std::string const& headerValue); + + std::string const& GetNegotiatedProtocol(); + bool IsOpen() { return m_state == SocketState::Open; } + bool HasBuiltInWebSocketSupport(); + + _internal::WebSocketStatistics GetStatistics() const; + + private: + // WebSocket opcodes. + enum class SocketOpcode : uint8_t + { + Continuation = 0x00, + TextFrame = 0x01, + BinaryFrame = 0x02, + Close = 0x08, + Ping = 0x09, + Pong = 0x0a + }; + + /** + * Indicates the type of the message currently being processed. Used when processing + * Continuation Opcode frames. + */ + enum class SocketMessageType : int + { + Unknown, + Text, + Binary, + }; + + class WebSocketInternalFrame { + public: + SocketOpcode Opcode{}; + bool IsFinalFrame{false}; + std::vector Payload; + std::exception_ptr Exception; + WebSocketInternalFrame( + SocketOpcode opcode, + bool isFinalFrame, + std::vector const& payload) + : Opcode(opcode), IsFinalFrame(isFinalFrame), Payload(payload) + { + } + WebSocketInternalFrame(std::exception_ptr exception) : Exception(exception) {} + }; + + struct ReceiveStatistics + { + std::atomic FramesSent; + std::atomic FramesReceived; + std::atomic BytesSent; + std::atomic BytesReceived; + std::atomic PingFramesSent; + std::atomic PingFramesReceived; + std::atomic PongFramesSent; + std::atomic PongFramesReceived; + std::atomic TextFramesReceived; + std::atomic BinaryFramesReceived; + std::atomic ContinuationFramesReceived; + std::atomic CloseFramesReceived; + std::atomic UnknownFramesReceived; + std::atomic FramesDropped; + std::atomic FramesDroppedByPayloadSizeLimit; + std::atomic FramesDroppedByProtocolError; + std::atomic TransportReads; + std::atomic TransportReadBytes; + std::atomic BinaryFramesSent; + std::atomic TextFramesSent; + std::atomic FramesDroppedByClose; + + void Reset() + { + FramesSent = 0; + BytesSent = 0; + FramesReceived = 0; + BytesReceived = 0; + PingFramesReceived = 0; + PingFramesSent = 0; + PongFramesReceived = 0; + PongFramesSent = 0; + TextFramesReceived = 0; + TextFramesSent = 0; + BinaryFramesReceived = 0; + BinaryFramesSent = 0; + ContinuationFramesReceived = 0; + CloseFramesReceived = 0; + UnknownFramesReceived = 0; + FramesDropped = 0; + FramesDroppedByClose = 0; + FramesDroppedByPayloadSizeLimit = 0; + FramesDroppedByProtocolError = 0; + TransportReads = 0; + TransportReadBytes = 0; + } + }; + /** + * @brief The PingThread handles sending Ping operations from the WebSocket server. + * + */ + class PingThread { + public: + /** + * @brief Construct a new ReceiveQueue object. + * + * @param webSocketImplementation Parent object, used to send Ping threads. + * @param pingInterval Interval to wait between sending pings. + */ + PingThread( + WebSocketImplementation* webSocketImplementation, + std::chrono::duration pingInterval); + /** + * @brief Destroys a ReceiveQueue object. Blocks until the queue thread is completed. + */ + ~PingThread(); + + /** + * @brief Start the receive queue. This will start a thread that will process incoming frames. + * + * @param transport The websocket transport to use for receiving frames. + */ + void Start(std::shared_ptr transport); + /** + * @brief Stop the receive queue. This will stop the thread that processes incoming frames. + */ + void Shutdown(); + + private: + /** + * @brief The receive queue thread. + */ + void PingThreadLoop(); + /** + * @brief Send a "ping" frame to the other side of the WebSocket. + * + * @returns True if the ping was sent, false if the underlying transport didn't support "Ping" + * operations. + */ + bool SendPing(std::vector const& pingData, Azure::Core::Context const& context); + + WebSocketImplementation* m_webSocketImplementation; + std::chrono::duration m_pingInterval; + std::thread m_pingThread; + std::mutex m_pingThreadStarted; + std::condition_variable m_pingThreadReady; + + std::mutex m_stopMutex; + std::condition_variable m_pingThreadStopped; + bool m_stop = false; + }; + + /** + * @brief Encode a websocket frame according to RFC 6455 section 5.2. + * + * This wire format for the data transfer part is described by the ABNF + * [RFC5234] given in detail in this section. (Note that, unlike in + * other sections of this document, the ABNF in this section is + * operating on groups of bits. The length of each group of bits is + * indicated in a comment. When encoded on the wire, the most + * significant bit is the leftmost in the ABNF). A high-level overview + * of the framing is given in the following figure. In a case of + * conflict between the figure below and the ABNF specified later in + * this section, the figure is authoritative. + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-------+-+-------------+-------------------------------+ + * |F|R|R|R| opcode|M| Payload len | Extended payload length | + * |I|S|S|S| (4) |A| (7) | (16/64) | + * |N|V|V|V| |S| | (if payload len==126/127) | + * | |1|2|3| |K| | | + * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + * | Extended payload length continued, if payload len == 127 | + * + - - - - - - - - - - - - - - - +-------------------------------+ + * | |Masking-key, if MASK set to 1 | + * +-------------------------------+-------------------------------+ + * | Masking-key (continued) | Payload Data | + * +-------------------------------- - - - - - - - - - - - - - - - + + * : Payload Data continued ... : + * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + * | Payload Data continued ... | + * +---------------------------------------------------------------+ + * + * FIN: 1 bit + * + * Indicates that this is the final fragment in a message. The first + * fragment MAY also be the final fragment. + * + * RSV1, RSV2, RSV3: 1 bit each + * + * MUST be 0 unless an extension is negotiated that defines meanings + * for non-zero values. If a nonzero value is received and none of + * the negotiated extensions defines the meaning of such a nonzero + * value, the receiving endpoint MUST _Fail the WebSocket + * Connection_. + * + * Opcode: 4 bits + * + * Defines the interpretation of the "Payload data". If an unknown + * opcode is received, the receiving endpoint MUST _Fail the + * WebSocket Connection_. The following values are defined. + * + * * %x0 denotes a continuation frame + * + * * %x1 denotes a text frame + * + * * %x2 denotes a binary frame + * + * * %x3-7 are reserved for further non-control frames + * + * * %x8 denotes a connection close + * + * * %x9 denotes a ping + * + * * %xA denotes a pong + * + * * %xB-F are reserved for further control frames + * + * Mask: 1 bit + * + * Defines whether the "Payload data" is masked. If set to 1, a + * masking key is present in masking-key, and this is used to unmask + * the "Payload data" as per Section 5.3. All frames sent from + * client to server have this bit set to 1. + * + * Payload length: 7 bits, 7+16 bits, or 7+64 bits + * + * The length of the "Payload data", in bytes: if 0-125, that is the + * payload length. If 126, the following 2 bytes interpreted as a + * 16-bit unsigned integer are the payload length. If 127, the + * following 8 bytes interpreted as a 64-bit unsigned integer (the + * most significant bit MUST be 0) are the payload length. Multibyte + * length quantities are expressed in network byte order. Note that + * in all cases, the minimal number of bytes MUST be used to encode + * the length, for example, the length of a 124-byte-long string + * can't be encoded as the sequence 126, 0, 124. The payload length + * is the length of the "Extension data" + the length of the + * "Application data". The length of the "Extension data" may be + * zero, in which case the payload length is the length of the + * "Application data". + * Masking-key: 0 or 4 bytes + * + * All frames sent from the client to the server are masked by a + * 32-bit value that is contained within the frame. This field is + * present if the mask bit is set to 1 and is absent if the mask bit + * is set to 0. See Section 5.3 for further information on client- + * to-server masking. + * + * Payload data: (x+y) bytes + * + * The "Payload data" is defined as "Extension data" concatenated + * with "Application data". + * + * Extension data: x bytes + * + * The "Extension data" is 0 bytes unless an extension has been + * negotiated. Any extension MUST specify the length of the + * "Extension data", or how that length may be calculated, and how + * the extension use MUST be negotiated during the opening handshake. + * If present, the "Extension data" is included in the total payload + * length. + * + * Application data: y bytes + * + * Arbitrary "Application data", taking up the remainder of the frame + * after any "Extension data". The length of the "Application data" + * is equal to the payload length minus the length of the "Extension + * data". + */ + static std::vector EncodeFrame( + SocketOpcode opcode, + bool isFinal, + std::vector const& payload); + + SocketState m_state{SocketState::Invalid}; + + std::vector GenerateRandomKey() { return GenerateRandomBytes(16); }; + void VerifySocketAccept(std::string const& encodedKey, std::string const& acceptHeader); + + /********* + * Buffered Read Support. Read data from the underlying transport into a buffer. + */ + uint8_t ReadTransportByte(Azure::Core::Context const& context); + uint16_t ReadTransportShort(Azure::Core::Context const& context); + uint64_t ReadTransportInt64(Azure::Core::Context const& context); + std::vector ReadTransportBytes(size_t readLength, Azure::Core::Context const& context); + bool IsTransportEof() const { return m_eof; } + void SendPong(std::vector const& pongData, Azure::Core::Context const& context); + void SendTransportBuffer( + std::vector const& payload, + Azure::Core::Context const& context); + std::shared_ptr ReceiveTransportFrame( + Azure::Core::Context const& context); + + /** + * @brief Decode a frame received from the websocket server. + * + * @returns A pointer to the start of the decoded data. + */ + std::shared_ptr DecodeFrame(Azure::Core::Context const& context); + + Azure::Core::Url m_remoteUrl; + _internal::WebSocketOptions m_options; + std::map m_headers; + std::string m_chosenProtocol; + std::shared_ptr m_transport; + PingThread m_pingThread; + SocketMessageType m_currentMessageType{SocketMessageType::Unknown}; + std::mutex m_stateMutex; + std::thread::id m_stateOwner; + + ReceiveStatistics m_receiveStatistics{}; + + std::mutex m_transportMutex; + + std::unique_ptr m_initialBodyStream; + constexpr static size_t m_bufferSize = 1024; + uint8_t m_buffer[m_bufferSize]{}; + size_t m_bufferPos = 0; + size_t m_bufferLen = 0; + bool m_eof = false; + }; +}}}}} // namespace Azure::Core::Http::WebSockets::_detail diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp new file mode 100644 index 000000000..7d869ca70 --- /dev/null +++ b/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/http.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/win_http_websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/platform.hpp" + +#if defined(AZ_PLATFORM_POSIX) +#include // for poll() +#include // for socket shutdown +#elif defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include // for WSAPoll(); +#endif +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + void WinHttpWebSocketTransport::OnUpgradedConnection( + Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) + { + // Convert the request handle into a WebSocket handle for us to use later. + m_socketHandle = Azure::Core::Http::_detail::unique_HINTERNET( + WinHttpWebSocketCompleteUpgrade(requestHandle.get(), 0), + Azure::Core::Http::_detail::HINTERNET_deleter{}); + if (!m_socketHandle) + { + GetErrorAndThrow("Error Upgrading HttpRequest handle to WebSocket handle."); + } + } + + std::unique_ptr WinHttpWebSocketTransport::Send( + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) + { + return WinHttpTransport::Send(request, context); + } + + /** + * @brief Close the WebSocket cleanly. + */ + void WinHttpWebSocketTransport::Close() { m_socketHandle.reset(); } + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + * + */ + void WinHttpWebSocketTransport::NativeCloseSocket( + uint16_t status, + std::string const& disconnectReason, + Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + + auto err = WinHttpWebSocketClose( + m_socketHandle.get(), + status, + disconnectReason.empty() + ? nullptr + : reinterpret_cast(const_cast(disconnectReason.c_str())), + static_cast(disconnectReason.size())); + if (err != 0) + { + GetErrorAndThrow("WinHttpWebSocketClose() failed", err); + } + + context.ThrowIfCancelled(); + + // Make sure that the server responds gracefully to the close request. + auto closeInformation = NativeGetCloseSocketInformation(context); + + // The server should return the same status we sent. + if (closeInformation.CloseReason != status) + { + throw std::runtime_error( + "Close status mismatch, got " + std::to_string(closeInformation.CloseReason) + + " expected " + std::to_string(status)); + } + } + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + WinHttpWebSocketTransport::NativeWebSocketCloseInformation + WinHttpWebSocketTransport::NativeGetCloseSocketInformation(Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + uint16_t closeStatus = 0; + char closeReason[WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH]{}; + DWORD closeReasonLength; + + auto err = WinHttpWebSocketQueryCloseStatus( + m_socketHandle.get(), + &closeStatus, + closeReason, + WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH, + &closeReasonLength); + if (err != 0) + { + GetErrorAndThrow("WinHttpGetCloseStatus() failed", err); + } + return NativeWebSocketCloseInformation{closeStatus, std::string(closeReason)}; + } + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native + * websockets. + * + * @brief frameType Frame type sent to the server, Text or Binary. + * @brief frameData Frame data to be sent to the server. + */ + void WinHttpWebSocketTransport::NativeSendFrame( + NativeWebSocketFrameType frameType, + std::vector const& frameData, + Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType; + switch (frameType) + { + case NativeWebSocketFrameType::Text: + bufferType = WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::Binary: + bufferType = WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::BinaryFragment: + bufferType = WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::TextFragment: + bufferType = WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE; + break; + default: + throw std::runtime_error( + "Unknown frame type: " + std::to_string(static_cast(frameType))); + break; + } + // Lock the socket to prevent concurrent writes. WinHTTP gets annoyed if + // there are multiple WinHttpWebSocketSend requests outstanding. + std::lock_guard lock(m_sendMutex); + auto err = WinHttpWebSocketSend( + m_socketHandle.get(), + bufferType, + reinterpret_cast(const_cast(frameData.data())), + static_cast(frameData.size())); + if (err != 0) + { + GetErrorAndThrow("WinHttpWebSocketSend() failed", err); + } + } + + WinHttpWebSocketTransport::NativeWebSocketReceiveInformation + WinHttpWebSocketTransport::NativeReceiveFrame(Azure::Core::Context const& context) + { + WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType; + NativeWebSocketFrameType frameTypeReceived; + DWORD bufferBytesRead; + std::vector buffer(128); + context.ThrowIfCancelled(); + std::lock_guard lock(m_receiveMutex); + + auto err = WinHttpWebSocketReceive( + m_socketHandle.get(), + reinterpret_cast(buffer.data()), + static_cast(buffer.size()), + &bufferBytesRead, + &bufferType); + if (err != 0 && err != ERROR_INSUFFICIENT_BUFFER) + { + GetErrorAndThrow("WinHttpWebSocketReceive() failed", err); + } + buffer.resize(bufferBytesRead); + + switch (bufferType) + { + case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Text; + break; + case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Binary; + break; + case WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::BinaryFragment; + break; + case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::TextFragment; + break; + case WINHTTP_WEB_SOCKET_CLOSE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Closed; + break; + default: + throw std::runtime_error("Unknown frame type: " + std::to_string(bufferType)); + break; + } + return NativeWebSocketReceiveInformation{frameTypeReceived, buffer}; + } + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 b/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 index acfc9bf0d..7e0828098 100644 --- a/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 +++ b/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-License-Identifier: MIT param( [string] $LogFileLocation = "$($env:BUILD_SOURCESDIRECTORY)/WebSocketServer.log" ) diff --git a/sdk/core/azure-core/test/ut/websocket_server.py b/sdk/core/azure-core/test/ut/websocket_server.py new file mode 100644 index 000000000..ca7bcc077 --- /dev/null +++ b/sdk/core/azure-core/test/ut/websocket_server.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-License-Identifier: MIT +from array import array +import asyncio +from operator import length_hint +import threading +from time import sleep +from urllib.parse import ParseResult, urlparse + +import websockets + +# create handler for each connection +customPaths = {} +stop = False + +async def handleControlPath(websocket): + while (1): + data : str = await websocket.recv() + parsedCommand = data.split(' ') + if (parsedCommand[0] == "close"): + print("Closing control channel") + await websocket.send("ok") + print("Terminating WebSocket server.") + stop.set_result(0) + break + elif parsedCommand[0] == "newPath": + print("Add path") + newPath = parsedCommand[1] + print(" Add path ", newPath) + customPaths[newPath] = {"path": newPath, "delay": int(parsedCommand[2]) } + await websocket.send("ok") + else: + print("Unknown command, echoing it.") + await websocket.send(data) + +async def handleCustomPath(websocket, path:dict): + print("Handle custom path", path) + data : str = await websocket.recv() + print("Received ", data) + if ("delay" in path.keys()): + sleep(path["delay"]) + print("Responding") + await websocket.send(data) + await websocket.close() + +def HexEncode(data: bytes)->str: + rv="" + for val in data: + rv+= '{:02X}'.format(val) + return rv + +def ParseQuery(url : ParseResult) -> dict: + rv={} + if len(url.query)!=0: + args = url.query.split('&') + for arg in args: + vals=arg.split('=') + rv[vals[0]]=vals[1] + return rv + +echo_count_lock = threading.Lock() +echo_count_recv = 0 +echo_count_send = 0 +client_count = 0 +async def handleEcho(websocket, url:ParseResult): + global client_count + global echo_count_recv + global echo_count_send + global echo_count_lock + queryValues = ParseQuery(url) + while websocket.open: + try: + data = await websocket.recv() + with echo_count_lock: + echo_count_recv+=1 + if 'delay' in queryValues: + print(f"sleeping for {queryValues['delay']} seconds") + await asyncio.sleep(float(queryValues['delay'])) + print("woken up.") + + if 'fragment' in queryValues and queryValues['fragment']=='true': + await websocket.send(data.split()) + else: + await websocket.send(data) + with echo_count_lock: + echo_count_send+=1 + except websockets.ConnectionClosedOK: + print("Connection closed ok.") + with echo_count_lock: + client_count -= 1 + print(f"Echo count: {echo_count_recv}, {echo_count_send} client_count {client_count}") + if client_count == 0: + echo_count_send = 0 + echo_count_recv = 0 + return + except websockets.ConnectionClosed as ex: + if (ex.rcvd): + print(f"Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}") + else: + print(f"Connection closed. No close information.") + with echo_count_lock: + client_count -= 1 + print(f"Echo count: recv: {echo_count_recv}, send: {echo_count_send} client_count {client_count}") + if client_count == 0: + echo_count_send = 0 + echo_count_recv = 0 + return + +async def handler(websocket, path : str): + global client_count + print("Socket handler: ", path) + parsedUrl = urlparse(path) + if (parsedUrl.path == '/openclosetest'): + print("Open/Close Test") + try: + data = await websocket.recv() + print(f"OpenCloseTest: Received {data}") + except websockets.ConnectionClosedOK: + print("OpenCloseTest: Connection closed ok.") + except websockets.ConnectionClosed as ex: + print(f"OpenCloseTest: Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}") + return + elif (parsedUrl.path == '/echotest'): + with echo_count_lock: + client_count+= 1 + await handleEcho(websocket, parsedUrl) + elif (parsedUrl.path == '/closeduringecho'): + data = await websocket.recv() + await websocket.close(1001, 'closed') + elif (parsedUrl.path =='/control'): + await handleControlPath(websocket) + elif (parsedUrl.path in customPaths.keys()): + print("Found path ", path, "in control paths.") + await handleCustomPath(websocket, customPaths[path]) + elif (parsedUrl.path == '/terminateserver'): + print("Terminating WebSocket server.") + stop.set_result(0) + else: + data = await websocket.recv() + print("Received: ", data) + + reply = f"Data received as: {data}!" + await websocket.send(reply) + +async def main(): + global stop + print("Starting server") + loop = asyncio.get_running_loop() + stop = loop.create_future() + async with websockets.serve(handler, "localhost", 8000, ping_interval=7): + await stop # run forever. + +if __name__=="__main__": + asyncio.run(main()) + print("Ending server") diff --git a/sdk/core/azure-core/test/ut/websocket_test.cpp b/sdk/core/azure-core/test/ut/websocket_test.cpp new file mode 100644 index 000000000..316093496 --- /dev/null +++ b/sdk/core/azure-core/test/ut/websocket_test.cpp @@ -0,0 +1,875 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "../../src/http/websockets/websockets_impl.hpp" +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/internal/json/json.hpp" +#include +#include +#include +#include +#include +#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#endif +// cspell::words closeme flibbityflobbidy + +using namespace Azure::Core; +using namespace Azure::Core::Http::WebSockets; +using namespace Azure::Core::Http::WebSockets::_internal; +using namespace std::chrono_literals; + +constexpr uint16_t UndefinedButLegalCloseReason = 4500; + +class WebSocketTests : public testing::Test { +private: +protected: + // Create + static void SetUpTestSuite() {} + static void TearDownTestSuite() {} +}; + +TEST_F(WebSocketTests, CreateSimpleSocket) +{ + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000")); + defaultSocket.AddHeader("newHeader", "headerValue"); + EXPECT_THROW(defaultSocket.GetNegotiatedProtocol(), std::runtime_error); + } +} + +TEST_F(WebSocketTests, OpenSimpleSocket) +{ + { + WebSocketOptions options; + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"), options); + defaultSocket.AddHeader("newHeader", "headerValue"); + + defaultSocket.Open(); + + EXPECT_THROW(defaultSocket.AddHeader("newHeader", "headerValue"), std::runtime_error); + + // Close the socket without notifying the peer. + defaultSocket.Close(); + } + + { + WebSocketOptions options; + WebSocket defaultSocket(Azure::Core::Url("http://www.microsoft.com/"), options); + defaultSocket.AddHeader("newHeader", "headerValue"); + + // When running this test locally, the call times out, so drop in a 5 second timeout on + // the request. + Azure::Core::Context requestContext = Azure::Core::Context::ApplicationContext.WithDeadline( + std::chrono::system_clock::now() + 5s); + EXPECT_THROW(defaultSocket.Open(requestContext), std::runtime_error); + } +} + +TEST_F(WebSocketTests, OpenAndCloseSocket) +{ + if (false) + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest")); + defaultSocket.AddHeader("newHeader", "headerValue"); + + defaultSocket.Open(); + + // Close the socket without notifying the peer. + defaultSocket.Close(UndefinedButLegalCloseReason); + } + + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest")); + + defaultSocket.Open(); + + // Close the socket without notifying the peer. + defaultSocket.Close(UndefinedButLegalCloseReason, "This is a good reason."); + + // + // Now re-open the socket - this should work to reset everything. + defaultSocket.Open(); + EXPECT_THROW(defaultSocket.Open(), std::runtime_error); + defaultSocket.Close(); + } +} + +TEST_F(WebSocketTests, SimpleEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + testSocket.SendFrame("Test message", true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::TextFrameReceived, response->FrameType); + EXPECT_THROW(response->AsBinaryFrame(), std::logic_error); + auto textResult = response->AsTextFrame(); + EXPECT_EQ("Test message", textResult->Text); + + // Close the socket gracefully. + testSocket.Close(); + } + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?delay=5")); + + testSocket.Open(); + + std::vector binaryData{1, 2, 3, 4, 5, 6}; + + testSocket.SendFrame(binaryData, true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + EXPECT_THROW(response->AsPeerCloseFrame(), std::logic_error); + EXPECT_THROW(response->AsTextFrame(), std::logic_error); + auto textResult = response->AsBinaryFrame(); + EXPECT_EQ(binaryData, textResult->Data); + + // Close the socket gracefully. + testSocket.Close(); + } + + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?fragment=true&delay=5")); + + testSocket.Open(); + + std::vector binaryData{1, 2, 3, 4, 5, 6}; + + testSocket.SendFrame(binaryData, true); + + std::vector responseData; + std::shared_ptr response; + do + { + response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + responseData.insert(responseData.end(), binaryResult->Data.begin(), binaryResult->Data.end()); + } while (!response->IsFinalFrame); + + auto textResult = response->AsBinaryFrame(); + EXPECT_EQ(binaryData, responseData); + + // Close the socket gracefully. + testSocket.Close(); + } +} + +template void EchoRandomData(WebSocket& socket) +{ + std::vector sendData = Azure::Core::Http::WebSockets::_detail::GenerateRandomBytes(N); + + socket.SendFrame(sendData, true); + + std::vector receiveData; + + std::shared_ptr response; + do + { + response = socket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + receiveData.insert(receiveData.end(), binaryResult->Data.begin(), binaryResult->Data.end()); + } while (!response->IsFinalFrame); + + // Make sure we get back the data we sent in the echo request. + EXPECT_EQ(sendData.size(), receiveData.size()); + EXPECT_EQ(sendData, receiveData); +} + +TEST_F(WebSocketTests, VariableSizeEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + { + EchoRandomData<100>(testSocket); + EchoRandomData<124>(testSocket); + EchoRandomData<125>(testSocket); + // The websocket protocol treats lengths of 125, 126 and > 127 specially. + EchoRandomData<126>(testSocket); + EchoRandomData<127>(testSocket); + EchoRandomData<128>(testSocket); + EchoRandomData<1020>(testSocket); // 1K-4 + EchoRandomData<1021>(testSocket); // 1K-3 + EchoRandomData<1022>(testSocket); // 1K-2 + EchoRandomData<1023>(testSocket); // 1K-1 + EchoRandomData<1024>(testSocket); // 1K + EchoRandomData<2048>(testSocket); // 2K + EchoRandomData<4096>(testSocket); // 4K + EchoRandomData<8192>(testSocket); // 8K + // The websocket protocol treats lengths of >65536 specially. + EchoRandomData<65535>(testSocket); // 64K-1 + EchoRandomData<65536>(testSocket); // 64K + EchoRandomData<65537>(testSocket); // 64K+1 + EchoRandomData<131072>(testSocket); // 128K + } + // Close the socket gracefully. + testSocket.Close(); + } +} + +// Generator for random bytes. Used in WebSocketImplementation and tests. +std::vector GenerateRandomBytes(size_t index, size_t vectorSize) +{ + std::random_device randomEngine; + + std::vector rv(vectorSize + 4); + rv[0] = index & 0xff; + rv[1] = (index >> 8) & 0xff; + rv[2] = (index >> 16) & 0xff; + rv[3] = (index >> 24) & 0xff; + std::generate(std::begin(rv) + 4, std::end(rv), [&randomEngine]() mutable { + return static_cast(randomEngine() % UINT8_MAX); + }); + return rv; +} + +TEST_F(WebSocketTests, CloseDuringEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho")); + + testSocket.Open(); + + testSocket.SendFrame("Test message", true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::PeerClosedReceived, response->FrameType); + auto PeerClosedReceived = response->AsPeerCloseFrame(); + EXPECT_EQ(1001, PeerClosedReceived->RemoteStatusCode); + + // Close the socket gracefully. + testSocket.Close(); + } + + // Close the websocket while a thread is waiting for a response. + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/echotest?delay=10")); + + testSocket.Open(); + + std::thread testThread([&]() { + try + { + std::vector sendData = GenerateRandomBytes(0, 100); + testSocket.SendFrame(sendData); + GTEST_LOG_(INFO) << "Receive frame."; + auto response = testSocket.ReceiveFrame(); + GTEST_LOG_(INFO) << "Received frame."; + if (response->FrameType == WebSocketFrameType::PeerClosedReceived) + { + GTEST_LOG_(INFO) << "Peer closed the socket; Terminating thread."; + return; + } + else if (response->FrameType != WebSocketFrameType::BinaryFrameReceived) + { + GTEST_LOG_(INFO) << "Unexpected frame type received."; + } + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() + << " Current Thread: " << std::this_thread::get_id() << std::endl; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + } + }); + + std::this_thread::sleep_for(100ms); + + // Close the socket gracefully. + GTEST_LOG_(INFO) << "Closing Socket."; + EXPECT_NO_THROW(testSocket.Close(UndefinedButLegalCloseReason, "Close Reason.")); + GTEST_LOG_(INFO) << "Closed Socket."; + testThread.join(); + } +} + +TEST_F(WebSocketTests, ExpectThrow) +{ + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho")); + + EXPECT_THROW(testSocket.SendFrame("Foo", true), std::runtime_error); + std::vector data{1, 2, 3, 4}; + EXPECT_THROW(testSocket.SendFrame(data, true), std::runtime_error); + EXPECT_THROW(testSocket.ReceiveFrame(), std::runtime_error); + } +} + +std::string ToHexString(std::vector const& data) +{ + std::stringstream ss; + for (auto const& byte : data) + { + ss << std::hex << std::setfill('0') << std::setw(2) << static_cast(byte); + } + return ss.str(); +} + +TEST_F(WebSocketTests, PingReceiveTest) +{ + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + if (!testSocket.HasBuiltInWebSocketSupport()) + { + + GTEST_LOG_(INFO) << "Sleeping for 15 seconds to collect pings."; + Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{std::chrono::system_clock::now() + 15s}); + EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException); + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + EXPECT_NE(0, statistics.PingFramesReceived); + EXPECT_NE(0, statistics.PongFramesSent); + } +} + +TEST_F(WebSocketTests, PingSendTest) +{ + // Configure the socket to ping every second. + WebSocketOptions socketOptions; + socketOptions.PingInterval = std::chrono::seconds(1); + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"), socketOptions); + + testSocket.Open(); + if (!testSocket.HasBuiltInWebSocketSupport()) + { + + GTEST_LOG_(INFO) << "Sleeping for 10 seconds to collect pings."; + // Note that we cannot collect incoming pings or outgoing pongs unless we are receiving + // data from the server. + Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{std::chrono::system_clock::now() + 10s}); + EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException); + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + EXPECT_NE(0, statistics.PingFramesSent); + EXPECT_NE(0, statistics.PongFramesReceived); + EXPECT_NE(0, statistics.PingFramesReceived); + EXPECT_NE(0, statistics.PongFramesSent); + } +} + +TEST_F(WebSocketTests, MultiThreadedTestOnSingleSocket) +{ + constexpr size_t threadCount = 50; + constexpr size_t testDataLength = 200000; + constexpr size_t testDataSize = 100; + constexpr auto testDuration = 10s; + + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + // seed test data for the operations. + std::vector> testData(testDataLength); + std::vector> receivedData(testDataLength); + std::atomic_size_t iterationCount(0); + + // Spin up threadCount threads and hammer the echo server for 10 seconds. + std::vector threads; + std::atomic_int32_t cancellationExceptions{0}; + std::atomic_int32_t exceptions{0}; + for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1) + { + threads.push_back(std::thread([&]() { + std::chrono::time_point startTime + = std::chrono::system_clock::now(); + // Set the context to expire *after* the test is supposed to finish. + Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{startTime} + testDuration + 10s); + size_t iteration = 0; + try + { + do + { + iteration = iterationCount++; + std::vector sendData = GenerateRandomBytes(iteration, testDataSize); + { + if (iteration < testData.size()) + { + if (testData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl; + } + EXPECT_EQ(0, testData[iteration].size()); + testData[iteration] = sendData; + } + } + + testSocket.SendFrame(sendData, true /*, context*/); + auto response = testSocket.ReceiveFrame(context); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + + // Make sure we get back the data we sent in the echo request. + if (binaryResult->Data.size() == 0) + { + GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl; + } + EXPECT_EQ(sendData.size(), binaryResult->Data.size()); + { + // There is no ordering expectation on the results, so we just remember the data + // as it comes in. We'll make sure we received everything later on. + if (iteration < receivedData.size()) + { + if (receivedData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration + << std::endl; + } + + EXPECT_EQ(0, receivedData[iteration].size()); + receivedData[iteration] = binaryResult->Data; + } + } + } while (std::chrono::system_clock::now() - startTime < testDuration); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration + << " Current Thread: " << std::this_thread::get_id() << std::endl; + cancellationExceptions++; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + exceptions++; + } + })); + } + + // Wait for all the threads to exit. + for (auto& thread : threads) + { + thread.join(); + } + + // We no longer need to worry about synchronization since all the worker threads are done. + GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl; + GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex + << testData.size() << ")" << std::endl; + EXPECT_GE(testDataLength, iterationCount.load()); + + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + + // Close the socket gracefully. + testSocket.Close(); + + EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesSent); + EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesReceived); + + // Resize the test data to the number of actual iterations. + testData.resize(iterationCount.load()); + receivedData.resize(iterationCount.load()); + + // If we've processed every iteration, let's make sure that we received everything we sent. + // If we dropped some results, then we can't check to ensure that we have received everything + // because we can't account for everything sent. + std::multiset testDataStrings; + std::multiset receivedDataStrings; + for (auto const& data : testData) + { + testDataStrings.emplace(ToHexString(data)); + } + for (auto const& data : receivedData) + { + receivedDataStrings.emplace(ToHexString(data)); + } + + EXPECT_EQ(testDataStrings, receivedDataStrings); + for (auto const& data : testDataStrings) + { + if (receivedDataStrings.count(data) != testDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data)); + } + for (auto const& data : receivedDataStrings) + { + if (testDataStrings.count(data) != receivedDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + + EXPECT_NE(testDataStrings.end(), testDataStrings.find(data)); + } + + // We shouldn't have seen any exceptions during the run. + EXPECT_EQ(0, exceptions.load()); + EXPECT_EQ(0, cancellationExceptions.load()); +} + +TEST_F(WebSocketTests, MultiThreadedTestOnMultipleSockets) +{ + constexpr size_t threadCount = 50; + constexpr size_t testDataLength = 200000; + constexpr size_t testDataSize = 100; + constexpr auto testDuration = 10s; + + // seed test data for the operations. + std::vector> testData(testDataLength); + std::vector> receivedData(testDataLength); + std::atomic_size_t iterationCount(0); + + // Spin up threadCount threads and hammer the echo server for 10 seconds. + std::vector threads; + std::atomic_int32_t cancellationExceptions{0}; + std::atomic_int32_t exceptions{0}; + for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1) + { + threads.push_back(std::thread([&]() { + std::chrono::time_point startTime + = std::chrono::system_clock::now(); + // Set the context to expire *after* the test is supposed to finish. + Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{startTime} + testDuration + 10s); + size_t iteration = 0; + try + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + do + { + iteration = iterationCount++; + std::vector sendData = GenerateRandomBytes(iteration, testDataSize); + { + if (iteration < testData.size()) + { + if (testData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl; + } + EXPECT_EQ(0, testData[iteration].size()); + testData[iteration] = sendData; + } + } + + testSocket.SendFrame(sendData, true /*, context*/); + auto response = testSocket.ReceiveFrame(context); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + + // Make sure we get back the data we sent in the echo request. + if (binaryResult->Data.size() == 0) + { + GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl; + } + EXPECT_EQ(sendData.size(), binaryResult->Data.size()); + { + // There is no ordering expectation on the results, so we just remember the data + // as it comes in. We'll make sure we received everything later on. + if (iteration < receivedData.size()) + { + if (receivedData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration + << std::endl; + } + + EXPECT_EQ(0, receivedData[iteration].size()); + receivedData[iteration] = binaryResult->Data; + } + } + } while (std::chrono::system_clock::now() - startTime < testDuration); + // Close the socket gracefully. + testSocket.Close(); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration + << " Current Thread: " << std::this_thread::get_id() << std::endl; + cancellationExceptions++; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + exceptions++; + } + })); + } + + // Wait for all the threads to exit. + for (auto& thread : threads) + { + thread.join(); + } + + // We no longer need to worry about synchronization since all the worker threads are done. + GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl; + GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex + << testData.size() << ")" << std::endl; + EXPECT_GE(testDataLength, iterationCount.load()); + + // Resize the test data to the number of actual iterations. + testData.resize(iterationCount.load()); + receivedData.resize(iterationCount.load()); + + // If we've processed every iteration, let's make sure that we received everything we sent. + // If we dropped some results, then we can't check to ensure that we have received everything + // because we can't account for everything sent. + std::multiset testDataStrings; + std::multiset receivedDataStrings; + for (auto const& data : testData) + { + testDataStrings.emplace(ToHexString(data)); + } + for (auto const& data : receivedData) + { + receivedDataStrings.emplace(ToHexString(data)); + } + + EXPECT_EQ(testDataStrings, receivedDataStrings); + for (auto const& data : testDataStrings) + { + if (receivedDataStrings.count(data) != testDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data)); + } + for (auto const& data : receivedDataStrings) + { + if (testDataStrings.count(data) != receivedDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + + EXPECT_NE(testDataStrings.end(), testDataStrings.find(data)); + } + + // We shouldn't have seen any exceptions during the run. + EXPECT_EQ(0, exceptions.load()); + EXPECT_EQ(0, cancellationExceptions.load()); +} + +// Does not work because curl rejects the wss: scheme. +class LibWebSocketIncrementProtocol { + WebSocketOptions m_options{{"dumb-increment-protocol"}}; + WebSocket m_socket; + +public: + LibWebSocketIncrementProtocol() : m_socket{Azure::Core::Url("wss://libwebsockets.org"), m_options} + { + } + + void Open() { m_socket.Open(); } + int GetNextNumber() + { + // Time out in 5 seconds if no activity. + Azure::Core::Context contextWithTimeout + = Azure::Core::Context().WithDeadline(std::chrono::system_clock::now() + 10s); + auto work = m_socket.ReceiveFrame(contextWithTimeout); + if (work->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto frame = work->AsTextFrame(); + return std::atoi(frame->Text.c_str()); + } + if (work->FrameType == WebSocketFrameType::BinaryFrameReceived) + { + auto frame = work->AsBinaryFrame(); + throw std::runtime_error("Not implemented"); + } + else if (work->FrameType == WebSocketFrameType::PeerClosedReceived) + { + GTEST_LOG_(INFO) << "Remote server closed connection." << std::endl; + throw std::runtime_error("Remote server closed connection."); + } + else + { + throw std::runtime_error("Unknown result type"); + } + } + + void Reset() { m_socket.SendFrame("reset\n", true); } + void RequestClose() { m_socket.SendFrame("closeme\n", true); } + void Close() { m_socket.Close(); } + void Close(uint16_t closeCode, std::string const& reasonText = {}) + { + m_socket.Close(closeCode, reasonText); + } + void ConsumeUntilClosed() + { + while (m_socket.IsOpen()) + { + auto work = m_socket.ReceiveFrame(); + if (work->FrameType == WebSocketFrameType::PeerClosedReceived) + { + auto peerClose = work->AsPeerCloseFrame(); + GTEST_LOG_(INFO) << "Peer closed. Remote Code: " << std::dec << peerClose->RemoteStatusCode + << " (0x" << std::hex << peerClose->RemoteStatusCode << ")" << std::endl; + if (!peerClose->RemoteCloseReason.empty()) + { + GTEST_LOG_(INFO) << " Peer Closed Data: " << peerClose->RemoteCloseReason; + } + GTEST_LOG_(INFO) << std::endl; + return; + } + else if (work->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto frame = work->AsTextFrame(); + GTEST_LOG_(INFO) << "Ignoring " << frame->Text << std::endl; + } + } + } +}; + +class LibWebSocketStatus { + +public: + std::string GetLWSStatus() + { + WebSocketOptions options; + + options.ServiceName = "websockettest"; + // Send 3 protocols to LWS. + options.Protocols.push_back("brownCow"); + options.Protocols.push_back("lws-status"); + options.Protocols.push_back("flibbityflobbidy"); + WebSocket serverSocket(Azure::Core::Url("wss://libwebsockets.org"), options); + serverSocket.Open(); + + // The server should have chosen the lws-status protocol since it doesn't understand the other + // protocols. + EXPECT_EQ("lws-status", serverSocket.GetNegotiatedProtocol()); + std::string returnValue; + std::shared_ptr lwsStatus; + do + { + + lwsStatus = serverSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::TextFrameReceived, lwsStatus->FrameType); + if (lwsStatus->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto textFrame = lwsStatus->AsTextFrame(); + returnValue.insert(returnValue.end(), textFrame->Text.begin(), textFrame->Text.end()); + } + } while (!lwsStatus->IsFinalFrame); + serverSocket.Close(); + return returnValue; + } +}; + +TEST_F(WebSocketTests, LibWebSocketOrgLwsStatus) +{ + { + LibWebSocketStatus lwsStatus; + auto serverStatus = lwsStatus.GetLWSStatus(); + GTEST_LOG_(INFO) << "Server status: " << serverStatus << std::endl; + + Azure::Core::Json::_internal::json status; + EXPECT_NO_THROW(status = Azure::Core::Json::_internal::json::parse(serverStatus)); + EXPECT_TRUE(status["conns"].is_array()); + auto& connections = status["conns"].get_ref&>(); + bool foundOurConnection = false; + + // Scan through the list of connections to find a connection from the websockettest. + for (auto& connection : connections) + { + EXPECT_TRUE(connection["ua"].is_string()); + auto userAgent = connection["ua"].get(); + if (userAgent.find("websockettest") != std::string::npos) + { + foundOurConnection = true; + break; + } + } + EXPECT_TRUE(foundOurConnection); + } +} +TEST_F(WebSocketTests, LibWebSocketOrgIncrement) +{ + { + LibWebSocketIncrementProtocol incrementProtocol; + incrementProtocol.Open(); + + // Note that we cannot practically validate the numbers received from the service because + // they may be in flight at the time the "Reset" call is made. + for (auto i = 0; i < 100; i += 1) + { + if (i % 5 == 0) + { + GTEST_LOG_(INFO) << "Reset" << std::endl; + incrementProtocol.Reset(); + } + int number = incrementProtocol.GetNextNumber(); + GTEST_LOG_(INFO) << "Got next number " << number << std::endl; + } + incrementProtocol.RequestClose(); + incrementProtocol.ConsumeUntilClosed(); + } +} +#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +TEST_F(WebSocketTests, CurlTransportCoverage) +{ + { + + Azure::Core::Http::WebSockets::CurlWebSocketTransportOptions transportOptions; + transportOptions.HttpKeepAlive = false; + auto transport + = std::make_shared(transportOptions); + + EXPECT_THROW(transport->NativeCloseSocket(1001, {}, {}), std::runtime_error); + EXPECT_THROW(transport->NativeGetCloseSocketInformation({}), std::runtime_error); + EXPECT_THROW( + transport->NativeSendFrame(WebSocketTransport::NativeWebSocketFrameType::Binary, {}, {}), + std::runtime_error); + EXPECT_THROW(transport->NativeReceiveFrame({}), std::runtime_error); + } +} +#endif