Enable binary support for WSClient

Currently, under python 3, the WSClient decodes all data via UTF-8. This
will break, e.g. capturing the stdout of tar or gzip.
This adds a new 'binary' kwarg to the WSClient class and websocket_call
function. If this is set to true, then the decoding will not happen, and
all channels will be interpreted as binary.
This does raise a slight complication, as the OpenAPI-generated client
will convert the output to a string, no matter what, which it ends up
doing by (effectively) calling repr(). This requires a bit of magic to
recover the orignial bytes, and is inefficient. However, this is only
the case when using the default _preload_content=True, setting this to
False and manually calling read_all or read_channel, this issue does not
arise.
This commit is contained in:
Andrew Melnick 2024-02-14 15:06:55 -07:00
parent 7712421cdc
commit 488518d957
3 changed files with 73 additions and 14 deletions

View File

@ -30,9 +30,18 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa
except AttributeError:
configuration = api_client.config
prev_request = api_client.request
binary = kwargs.pop('binary', False)
try:
api_client.request = functools.partial(websocket_request, configuration)
return api_method(*args, **kwargs)
api_client.request = functools.partial(websocket_request, configuration, binary=binary)
out = api_method(*args, **kwargs)
# The api_client insists on converting this to a string using its representation, so we have
# to do this dance to strip it of the b' prefix and ' suffix, encode it byte-per-byte (latin1),
# escape all of the unicode \x*'s, then encode it back byte-by-byte
# However, if _preload_content=False is passed, then the entire WSClient is returned instead
# of a response, and we want to leave it alone
if binary and kwargs.get('_preload_content', True):
out = out[2:-1].encode('latin1').decode('unicode_escape').encode('latin1')
return out
finally:
api_client.request = prev_request

View File

@ -26,8 +26,9 @@ import time
import six
import yaml
from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO
from six import StringIO, BytesIO
from websocket import WebSocket, ABNF, enableTrace
from base64 import urlsafe_b64decode
@ -48,7 +49,7 @@ class _IgnoredIO:
class WSClient:
def __init__(self, configuration, url, headers, capture_all):
def __init__(self, configuration, url, headers, capture_all, binary=False):
"""A websocket client with support for channels.
Exec command uses different channels for different streams. for
@ -58,8 +59,10 @@ class WSClient:
"""
self._connected = False
self._channels = {}
self.binary = binary
self.newline = '\n' if not self.binary else b'\n'
if capture_all:
self._all = StringIO()
self._all = StringIO() if not self.binary else BytesIO()
else:
self._all = _IgnoredIO()
self.sock = create_websocket(configuration, url, headers)
@ -92,8 +95,8 @@ class WSClient:
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
if self.newline in data:
index = data.find(self.newline)
ret = data[:index]
data = data[index+1:]
if data:
@ -197,10 +200,12 @@ class WSClient:
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
if six.PY3 and not self.binary:
data = data.decode("utf-8", "replace")
if len(data) > 1:
channel = ord(data[0])
channel = data[0]
if six.PY3 and not self.binary:
channel = ord(channel)
data = data[1:]
if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
@ -518,13 +523,17 @@ def websocket_call(configuration, _method, url, **kwargs):
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)
binary = kwargs.get('binary', False)
try:
client = WSClient(configuration, url, headers, capture_all)
client = WSClient(configuration, url, headers, capture_all, binary=binary)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
all = client.read_all()
if binary:
return WSResponse(all)
else:
return WSResponse('%s' % ''.join(all))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

View File

@ -20,6 +20,8 @@ import time
import unittest
import uuid
import six
import io
import gzip
from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
@ -118,15 +120,28 @@ class TestClient(unittest.TestCase):
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, str)
self.assertEqual(3, len(resp.splitlines()))
exec_command = ['/bin/sh',
'-c',
'echo -n "This is a test string" | gzip']
resp = stream(api.connect_get_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False,
binary=True)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, bytes)
self.assertEqual("This is a test string", gzip.decompress(resp).decode('utf-8'))
exec_command = 'uptime'
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s' % repr(resp))
self.assertEqual(1, len(resp.splitlines()))
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
@ -154,6 +169,32 @@ class TestClient(unittest.TestCase):
resp.update(timeout=5)
self.assertFalse(resp.is_open())
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
binary=True,
_preload_content=False)
resp.write_stdin(b"echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual(b"test string 1", line)
resp.write_stdin(b"echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual(b"test string 2", line)
resp.write_stdin(b"exit\n")
resp.update(timeout=5)
while True:
line = resp.read_channel(ERROR_CHANNEL)
if len(line) != 0:
break
time.sleep(1)
status = json.loads(line)
self.assertEqual(status['status'], 'Success')
resp.update(timeout=5)
self.assertFalse(resp.is_open())
number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)