From 7bf04b384b8cfcdba6387cf61e1cd9d6052669ee Mon Sep 17 00:00:00 2001 From: "Patrick J. McNerthney" Date: Sun, 6 Sep 2020 09:25:58 -1000 Subject: [PATCH] Rework how the PortForward._proxy thread determines when and how to terminate. --- stream/ws_client.py | 153 +++++++++++++++++++++++--------------------- 1 file changed, 79 insertions(+), 74 deletions(-) diff --git a/stream/ws_client.py b/stream/ws_client.py index 971ab6b48..fafba79a6 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -238,33 +238,51 @@ class PortForward: self.websocket = websocket self.local_ports = {} - for ix, local_remote in enumerate(ports): - self.local_ports[local_remote[0]] = self._Port(ix, local_remote[1]) + for ix, port_number in enumerate(ports): + self.local_ports[port_number] = self._Port(ix, port_number) + # There is a thread run per PortForward instance which performs the translation between the + # raw socket data sent by the python application and the websocket protocol. This thread + # terminates after either side has closed all ports, and after flushing all pending data. threading.Thread( - name="Kubernetes port forward proxy", target=self._proxy, daemon=True + name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]), + target=self._proxy, + daemon=True ).start() - def socket(self, local_number): - if local_number not in self.local_ports: + def socket(self, port_number): + if port_number not in self.local_ports: raise ValueError("Invalid port number") - return self.local_ports[local_number].socket + return self.local_ports[port_number].socket - def error(self, local_number): - if local_number not in self.local_ports: + def error(self, port_number): + if port_number not in self.local_ports: raise ValueError("Invalid port number") - return self.local_ports[local_number].error + return self.local_ports[port_number].error def close(self): for port in self.local_ports.values(): port.socket.close() class _Port: - def __init__(self, ix, remote_number): - self.remote_number = remote_number + def __init__(self, ix, port_number): + # The remote port number + self.port_number = port_number + # The websocket channel byte number for this port self.channel = bytes([ix * 2]) + # A socket pair is created to provide a means of translating the data flow + # between the python application and the kubernetes websocket. The self.python + # half of the socket pair is used by the _proxy method to receive and send data + # to the running python application. s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + # The self.socket half of the pair is used by the python application to send + # and receive data to the eventual pod port. It is wrapped in the _Socket class + # because a socket pair is an AF_UNIX socket, not a AF_NET socket. This allows + # intercepting setting AF_INET socket options that would error against an AD_UNIX + # socket. self.socket = self._Socket(s) + # Data accumulated from the websocket to be sent to the python application. self.data = b'' + # All data sent from kubernetes on the port error channel. self.error = None class _Socket: @@ -285,8 +303,7 @@ class PortForward: def _proxy(self): channel_ports = [] channel_initialized = [] - python_ports = {} - rlist = [] + local_ports = {} for port in self.local_ports.values(): # Setup the data channel for this port number channel_ports.append(port) @@ -294,33 +311,36 @@ class PortForward: # Setup the error channel for this port number channel_ports.append(port) channel_initialized.append(False) - python_ports[port.python] = port - rlist.append(port.python) - rlist.append(self.websocket.sock) + port.python.setblocking(True) + local_ports[port.python] = port + # The data to send on the websocket socket kubernetes_data = b'' while True: - wlist = [] + rlist = [] # List of sockets to read from + wlist = [] # List of sockets to write to + if self.websocket.connected: + rlist.append(self.websocket) + if kubernetes_data: + wlist.append(self.websocket) + all_closed = True for port in self.local_ports.values(): - if port.data: - wlist.append(port.python) - if kubernetes_data: - wlist.append(self.websocket.sock) - r, w, _ = select.select(rlist, wlist, []) - for s in w: - if s == self.websocket.sock: - sent = self.websocket.sock.send(kubernetes_data) - kubernetes_data = kubernetes_data[sent:] - else: - port = python_ports[s] - sent = port.python.send(port.data) - port.data = port.data[sent:] - for s in r: - if s == self.websocket.sock: - opcode, frame = self.websocket.recv_data_frame(True) - if opcode == ABNF.OPCODE_CLOSE: - for port in self.local_ports.values(): + if port.python.fileno() != -1: + if port.data: + wlist.append(port.python) + all_closed = False + else: + if self.websocket.connected: + rlist.append(port.python) + all_closed = False + else: port.python.close() - return + if all_closed and (not self.websocket.connected or not kubernetes_data): + self.websocket.close() + return + r, w, _ = select.select(rlist, wlist, []) + for sock in r: + if sock == self.websocket: + opcode, frame = self.websocket.recv_data_frame(True) if opcode == ABNF.OPCODE_BINARY: if not frame.data: raise RuntimeError("Unexpected frame data size") @@ -341,15 +361,15 @@ class PortForward: "Unexpected initial channel frame data size" ) port_number = frame.data[1] + (frame.data[2] * 256) - if port_number != port.remote_number: + if port_number != port.port_number: raise RuntimeError( "Unexpected port number in initial channel frame: " + str(port_number) ) channel_initialized[channel] = True - elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG): + elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE): raise RuntimeError("Unexpected websocket opcode: " + str(opcode)) else: - port = python_ports[s] + port = local_ports[sock] data = port.python.recv(1024 * 1024) if data: kubernetes_data += ABNF.create_frame( @@ -357,11 +377,16 @@ class PortForward: ABNF.OPCODE_BINARY, ).format() else: - port.python.close() - rlist.remove(s) - if len(rlist) == 1: - self.websocket.close() - return + if not port.data: + port.python.close() + for sock in w: + if sock == self.websocket: + sent = self.websocket.sock.send(kubernetes_data) + kubernetes_data = kubernetes_data[sent:] + else: + port = local_ports[sock] + sent = port.python.send(port.data) + port.data = port.data[sent:] def get_websocket_url(url, query_params=None): @@ -451,38 +476,18 @@ def portforward_call(configuration, _method, url, **kwargs): query_params = kwargs.get("query_params") ports = [] - for ix in range(len(query_params)): - if query_params[ix][0] == 'ports': - remote_ports = [] - for port in query_params[ix][1].split(','): + for param, value in query_params: + if param == 'ports': + for port in value.split(','): try: - local_remote = port.split(':') - if len(local_remote) > 2: - raise ValueError - if len(local_remote) == 1: - local_remote[0] = int(local_remote[0]) - if not (0 < local_remote[0] < 65536): - raise ValueError - local_remote.append(local_remote[0]) - elif len(local_remote) == 2: - if local_remote[0]: - local_remote[0] = int(local_remote[0]) - if not (0 <= local_remote[0] < 65536): - raise ValueError - else: - local_remote[0] = 0 - local_remote[1] = int(local_remote[1]) - if not (0 < local_remote[1] < 65536): - raise ValueError - if not local_remote[0]: - local_remote[0] = len(ports) + 1 - else: - raise ValueError - ports.append(local_remote) - remote_ports.append(str(local_remote[1])) + port_number = int(port) except ValueError: - raise ApiValueError("Invalid port number `" + port + "`") - query_params[ix] = ('ports', ','.join(remote_ports)) + raise ApiValueError("Invalid port number: %s" % port) + if not (0 < port_number < 65536): + raise ApiValueError("Port number must be between 0 and 65536: %s" % port) + if port_number in ports: + raise ApiValueError("Duplicate port numbers: %s" % port) + ports.append(port_number) if not ports: raise ApiValueError("Missing required parameter `ports`")