Rework how the PortForward._proxy thread determines when and how to terminate.

This commit is contained in:
Patrick J. McNerthney 2020-09-06 09:25:58 -10:00
parent 72e372599d
commit 7bf04b384b

View File

@ -238,33 +238,51 @@ class PortForward:
self.websocket = websocket
self.local_ports = {}
for ix, local_remote in enumerate(ports):
self.local_ports[local_remote[0]] = self._Port(ix, local_remote[1])
for ix, port_number in enumerate(ports):
self.local_ports[port_number] = self._Port(ix, port_number)
# There is a thread run per PortForward instance which performs the translation between the
# raw socket data sent by the python application and the websocket protocol. This thread
# terminates after either side has closed all ports, and after flushing all pending data.
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
target=self._proxy,
daemon=True
).start()
def socket(self, local_number):
if local_number not in self.local_ports:
def socket(self, port_number):
if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[local_number].socket
return self.local_ports[port_number].socket
def error(self, local_number):
if local_number not in self.local_ports:
def error(self, port_number):
if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[local_number].error
return self.local_ports[port_number].error
def close(self):
for port in self.local_ports.values():
port.socket.close()
class _Port:
def __init__(self, ix, remote_number):
self.remote_number = remote_number
def __init__(self, ix, port_number):
# The remote port number
self.port_number = port_number
# The websocket channel byte number for this port
self.channel = bytes([ix * 2])
# A socket pair is created to provide a means of translating the data flow
# between the python application and the kubernetes websocket. The self.python
# half of the socket pair is used by the _proxy method to receive and send data
# to the running python application.
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
# The self.socket half of the pair is used by the python application to send
# and receive data to the eventual pod port. It is wrapped in the _Socket class
# because a socket pair is an AF_UNIX socket, not a AF_NET socket. This allows
# intercepting setting AF_INET socket options that would error against an AD_UNIX
# socket.
self.socket = self._Socket(s)
# Data accumulated from the websocket to be sent to the python application.
self.data = b''
# All data sent from kubernetes on the port error channel.
self.error = None
class _Socket:
@ -285,8 +303,7 @@ class PortForward:
def _proxy(self):
channel_ports = []
channel_initialized = []
python_ports = {}
rlist = []
local_ports = {}
for port in self.local_ports.values():
# Setup the data channel for this port number
channel_ports.append(port)
@ -294,33 +311,36 @@ class PortForward:
# Setup the error channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
python_ports[port.python] = port
rlist.append(port.python)
rlist.append(self.websocket.sock)
port.python.setblocking(True)
local_ports[port.python] = port
# The data to send on the websocket socket
kubernetes_data = b''
while True:
wlist = []
rlist = [] # List of sockets to read from
wlist = [] # List of sockets to write to
if self.websocket.connected:
rlist.append(self.websocket)
if kubernetes_data:
wlist.append(self.websocket)
all_closed = True
for port in self.local_ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
wlist.append(self.websocket.sock)
r, w, _ = select.select(rlist, wlist, [])
for s in w:
if s == self.websocket.sock:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = python_ports[s]
sent = port.python.send(port.data)
port.data = port.data[sent:]
for s in r:
if s == self.websocket.sock:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.local_ports.values():
if port.python.fileno() != -1:
if port.data:
wlist.append(port.python)
all_closed = False
else:
if self.websocket.connected:
rlist.append(port.python)
all_closed = False
else:
port.python.close()
return
if all_closed and (not self.websocket.connected or not kubernetes_data):
self.websocket.close()
return
r, w, _ = select.select(rlist, wlist, [])
for sock in r:
if sock == self.websocket:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
@ -341,15 +361,15 @@ class PortForward:
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.remote_number:
if port_number != port.port_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
else:
port = python_ports[s]
port = local_ports[sock]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
@ -357,11 +377,16 @@ class PortForward:
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
rlist.remove(s)
if len(rlist) == 1:
self.websocket.close()
return
if not port.data:
port.python.close()
for sock in w:
if sock == self.websocket:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = local_ports[sock]
sent = port.python.send(port.data)
port.data = port.data[sent:]
def get_websocket_url(url, query_params=None):
@ -451,38 +476,18 @@ def portforward_call(configuration, _method, url, **kwargs):
query_params = kwargs.get("query_params")
ports = []
for ix in range(len(query_params)):
if query_params[ix][0] == 'ports':
remote_ports = []
for port in query_params[ix][1].split(','):
for param, value in query_params:
if param == 'ports':
for port in value.split(','):
try:
local_remote = port.split(':')
if len(local_remote) > 2:
raise ValueError
if len(local_remote) == 1:
local_remote[0] = int(local_remote[0])
if not (0 < local_remote[0] < 65536):
raise ValueError
local_remote.append(local_remote[0])
elif len(local_remote) == 2:
if local_remote[0]:
local_remote[0] = int(local_remote[0])
if not (0 <= local_remote[0] < 65536):
raise ValueError
else:
local_remote[0] = 0
local_remote[1] = int(local_remote[1])
if not (0 < local_remote[1] < 65536):
raise ValueError
if not local_remote[0]:
local_remote[0] = len(ports) + 1
else:
raise ValueError
ports.append(local_remote)
remote_ports.append(str(local_remote[1]))
port_number = int(port)
except ValueError:
raise ApiValueError("Invalid port number `" + port + "`")
query_params[ix] = ('ports', ','.join(remote_ports))
raise ApiValueError("Invalid port number: %s" % port)
if not (0 < port_number < 65536):
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
if port_number in ports:
raise ApiValueError("Duplicate port numbers: %s" % port)
ports.append(port_number)
if not ports:
raise ApiValueError("Missing required parameter `ports`")