Rework the parsing of the requested ports to support both a local port and a remote port.

This commit is contained in:
Patrick J. McNerthney 2020-09-01 18:33:33 -10:00
parent cc9ae10549
commit 72e372599d

View File

@ -237,30 +237,30 @@ class PortForward:
"""
self.websocket = websocket
self.ports = {}
for ix, port_number in enumerate(ports):
self.ports[port_number] = self._Port(ix, port_number)
self.local_ports = {}
for ix, local_remote in enumerate(ports):
self.local_ports[local_remote[0]] = self._Port(ix, local_remote[1])
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
).start()
def socket(self, port_number):
if port_number not in self.ports:
def socket(self, local_number):
if local_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.ports[port_number].socket
return self.local_ports[local_number].socket
def error_channel(self, port_number):
if port_number not in self.ports:
def error(self, local_number):
if local_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.ports[port_number].error
return self.local_ports[local_number].error
def close(self):
for port in self.ports.values():
for port in self.local_ports.values():
port.socket.close()
class _Port:
def __init__(self, ix, number):
self.number = number
def __init__(self, ix, remote_number):
self.remote_number = remote_number
self.channel = bytes([ix * 2])
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket = self._Socket(s)
@ -287,7 +287,7 @@ class PortForward:
channel_initialized = []
python_ports = {}
rlist = []
for port in self.ports.values():
for port in self.local_ports.values():
# Setup the data channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
@ -300,7 +300,7 @@ class PortForward:
kubernetes_data = b''
while True:
wlist = []
for port in self.ports.values():
for port in self.local_ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
@ -318,7 +318,7 @@ class PortForward:
if s == self.websocket.sock:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.ports.values():
for port in self.local_ports.values():
port.python.close()
return
if opcode == ABNF.OPCODE_BINARY:
@ -330,11 +330,9 @@ class PortForward:
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
port.error = frame.data[1:].decode()
if port.python in rlist:
port.python.close()
rlist.remove(port.python)
port.data = b''
if port.error is None:
port.error = ''
port.error += frame.data[1:].decode()
else:
port.data += frame.data[1:]
else:
@ -343,7 +341,7 @@ class PortForward:
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.number:
if port_number != port.remote_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
@ -453,17 +451,38 @@ def portforward_call(configuration, _method, url, **kwargs):
query_params = kwargs.get("query_params")
ports = []
for key, value in query_params:
if key == 'ports':
for port in value.split(','):
for ix in range(len(query_params)):
if query_params[ix][0] == 'ports':
remote_ports = []
for port in query_params[ix][1].split(','):
try:
# The last specified port is the remote port
port = int(port.split(':')[-1])
if not (0 < port < 65536):
local_remote = port.split(':')
if len(local_remote) > 2:
raise ValueError
ports.append(port)
if len(local_remote) == 1:
local_remote[0] = int(local_remote[0])
if not (0 < local_remote[0] < 65536):
raise ValueError
local_remote.append(local_remote[0])
elif len(local_remote) == 2:
if local_remote[0]:
local_remote[0] = int(local_remote[0])
if not (0 <= local_remote[0] < 65536):
raise ValueError
else:
local_remote[0] = 0
local_remote[1] = int(local_remote[1])
if not (0 < local_remote[1] < 65536):
raise ValueError
if not local_remote[0]:
local_remote[0] = len(ports) + 1
else:
raise ValueError
ports.append(local_remote)
remote_ports.append(str(local_remote[1]))
except ValueError:
raise ApiValueError("Invalid port number `" + str(port) + "`")
raise ApiValueError("Invalid port number `" + port + "`")
query_params[ix] = ('ports', ','.join(remote_ports))
if not ports:
raise ApiValueError("Missing required parameter `ports`")