Implement port forwarding.

This commit is contained in:
Patrick J. McNerthney 2020-08-23 13:34:41 -10:00
parent 471a67844e
commit 74d0e292b8
3 changed files with 178 additions and 4 deletions

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .stream import stream
from .stream import stream, portforward

View File

@ -17,9 +17,12 @@ import functools
from . import ws_client
def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs):
"""Override the ApiClient.request method with an alternative websocket based
method and call the supplied Kubernetes API method with that in place."""
if force_kwargs:
for kwarg, value in force_kwargs.items():
kwargs[kwarg] = value
api_client = api_method.__self__.api_client
# old generated code's api client has config. new ones has configuration
try:
@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
api_client.request = prev_request
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call)
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None)
portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False})

View File

@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from kubernetes.client.rest import ApiException
from kubernetes.client.rest import ApiException, ApiValueError
import certifi
import collections
import select
import socket
import ssl
import threading
import time
import six
@ -225,6 +227,143 @@ class WSClient:
WSResponse = collections.namedtuple('WSResponse', ['data'])
class PortForward:
def __init__(self, websocket, ports):
"""A websocket client with support for port forwarding.
Port Forward command sends on 2 channels per port, a read/write
data channel and a read only error channel. Both channels are sent an
initial frame contaning the port number that channel is associated with.
"""
self.websocket = websocket
self.ports = {}
for ix, port_number in enumerate(ports):
self.ports[port_number] = self._Port(ix, port_number)
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
).start()
def socket(self, port_number):
if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].socket
def error(self, port_number):
if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].error
def close(self):
for port in self.ports.values():
port.socket.close()
class _Port:
def __init__(self, ix, number):
self.number = number
self.channel = bytes([ix * 2])
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket = self._Socket(s)
self.data = b''
self.error = None
class _Socket:
def __init__(self, socket):
self._socket = socket
def __getattr__(self, name):
return getattr(self._socket, name)
def setsockopt(self, level, optname, value):
# The following socket option is not valid with a socket created from socketpair,
# and is set when creating an SSLSocket from this socket.
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
return
self._socket.setsockopt(level, optname, value)
# Proxy all socket data between the python code and the kubernetes websocket.
def _proxy(self):
channel_ports = []
channel_initialized = []
python_ports = {}
rlist = []
for port in self.ports.values():
channel_ports.append(port)
channel_initialized.append(False)
channel_ports.append(port)
channel_initialized.append(False)
python_ports[port.python] = port
rlist.append(port.python)
rlist.append(self.websocket.sock)
kubernetes_data = b''
while True:
wlist = []
for port in self.ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
wlist.append(self.websocket.sock)
r, w, _ = select.select(rlist, wlist, [])
for s in w:
if s == self.websocket.sock:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = python_ports[s]
sent = port.python.send(port.data)
port.data = port.data[sent:]
for s in r:
if s == self.websocket.sock:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.ports.values():
port.python.close()
return
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = frame.data[0]
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: " + str(channel))
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
port.error = frame.data[1:].decode()
if port.python in rlist:
port.python.close()
rlist.remove(port.python)
port.data = b''
else:
port.data += frame.data[1:]
else:
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
else:
port = python_ports[s]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
rlist.remove(s)
if len(rlist) == 1:
self.websocket.close()
return
def get_websocket_url(url, query_params=None):
parsed_url = urlparse(url)
parts = list(parsed_url)
@ -302,3 +441,34 @@ def websocket_call(configuration, _method, url, **kwargs):
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))
def portforward_call(configuration, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required for port forwarding. args and kwargs are the
parameters of apiClient.request method."""
query_params = kwargs.get("query_params")
ports = []
for key, value in query_params:
if key == 'ports':
for port in value.split(','):
try:
port = int(port)
if not (0 < port < 65536):
raise ValueError
ports.append(port)
except ValueError:
raise ApiValueError("Invalid port number `" + str(port) + "`")
if not ports:
raise ApiValueError("Missing required parameter `ports`")
url = get_websocket_url(url, query_params)
headers = kwargs.get("headers")
try:
websocket = create_websocket(configuration, url, headers)
return PortForward(websocket, ports)
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))