Merge pull request #2194 from meln5674/feature/binary-wsclient

Enable binary support for WSClient
This commit is contained in:
Kubernetes Prow Robot 2024-02-28 21:00:27 -08:00 committed by GitHub
commit 851dc2a0b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 14 deletions

View File

@ -30,9 +30,18 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa
except AttributeError:
configuration = api_client.config
prev_request = api_client.request
binary = kwargs.pop('binary', False)
try:
api_client.request = functools.partial(websocket_request, configuration)
return api_method(*args, **kwargs)
api_client.request = functools.partial(websocket_request, configuration, binary=binary)
out = api_method(*args, **kwargs)
# The api_client insists on converting this to a string using its representation, so we have
# to do this dance to strip it of the b' prefix and ' suffix, encode it byte-per-byte (latin1),
# escape all of the unicode \x*'s, then encode it back byte-by-byte
# However, if _preload_content=False is passed, then the entire WSClient is returned instead
# of a response, and we want to leave it alone
if binary and kwargs.get('_preload_content', True):
out = out[2:-1].encode('latin1').decode('unicode_escape').encode('latin1')
return out
finally:
api_client.request = prev_request

View File

@ -26,8 +26,9 @@ import time
import six
import yaml
from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO
from six import StringIO, BytesIO
from websocket import WebSocket, ABNF, enableTrace
from base64 import urlsafe_b64decode
@ -48,7 +49,7 @@ class _IgnoredIO:
class WSClient:
def __init__(self, configuration, url, headers, capture_all):
def __init__(self, configuration, url, headers, capture_all, binary=False):
"""A websocket client with support for channels.
Exec command uses different channels for different streams. for
@ -58,8 +59,10 @@ class WSClient:
"""
self._connected = False
self._channels = {}
self.binary = binary
self.newline = '\n' if not self.binary else b'\n'
if capture_all:
self._all = StringIO()
self._all = StringIO() if not self.binary else BytesIO()
else:
self._all = _IgnoredIO()
self.sock = create_websocket(configuration, url, headers)
@ -92,8 +95,8 @@ class WSClient:
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
if self.newline in data:
index = data.find(self.newline)
ret = data[:index]
data = data[index+1:]
if data:
@ -197,10 +200,12 @@ class WSClient:
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
if six.PY3 and not self.binary:
data = data.decode("utf-8", "replace")
if len(data) > 1:
channel = ord(data[0])
channel = data[0]
if six.PY3 and not self.binary:
channel = ord(channel)
data = data[1:]
if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
@ -518,13 +523,17 @@ def websocket_call(configuration, _method, url, **kwargs):
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)
binary = kwargs.get('binary', False)
try:
client = WSClient(configuration, url, headers, capture_all)
client = WSClient(configuration, url, headers, capture_all, binary=binary)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
all = client.read_all()
if binary:
return WSResponse(all)
else:
return WSResponse('%s' % ''.join(all))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

View File

@ -20,6 +20,8 @@ import time
import unittest
import uuid
import six
import io
import gzip
from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
@ -118,15 +120,28 @@ class TestClient(unittest.TestCase):
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, str)
self.assertEqual(3, len(resp.splitlines()))
exec_command = ['/bin/sh',
'-c',
'echo -n "This is a test string" | gzip']
resp = stream(api.connect_get_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False,
binary=True)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, bytes)
self.assertEqual("This is a test string", gzip.decompress(resp).decode('utf-8'))
exec_command = 'uptime'
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s' % repr(resp))
self.assertEqual(1, len(resp.splitlines()))
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
@ -154,6 +169,32 @@ class TestClient(unittest.TestCase):
resp.update(timeout=5)
self.assertFalse(resp.is_open())
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
binary=True,
_preload_content=False)
resp.write_stdin(b"echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual(b"test string 1", line)
resp.write_stdin(b"echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual(b"test string 2", line)
resp.write_stdin(b"exit\n")
resp.update(timeout=5)
while True:
line = resp.read_channel(ERROR_CHANNEL)
if len(line) != 0:
break
time.sleep(1)
status = json.loads(line)
self.assertEqual(status['status'], 'Success')
resp.update(timeout=5)
self.assertFalse(resp.is_open())
number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)