Merge pull request #125 from mbohlool/exec

Improvements on ws_client
This commit is contained in:
Mehdy Bohlool 2017-02-21 13:46:01 -08:00 committed by GitHub
commit 436351b027
6 changed files with 294 additions and 52 deletions

View File

@ -1,3 +1,6 @@
# v1.0.0b2
- Support exec calls in both interactive and non-interactive mode #58
# v1.0.0b1
- Support insecure-skip-tls-verify config flag #99

94
examples/exec.py Normal file
View File

@ -0,0 +1,94 @@
import time
from kubernetes import config
from kubernetes.client import configuration
from kubernetes.client.apis import core_v1_api
from kubernetes.client.rest import ApiException
config.load_kube_config()
configuration.assert_hostname = False
api = core_v1_api.CoreV1Api()
name = 'busybox-test'
resp = None
try:
resp = api.read_namespaced_pod(name=name,
namespace='default')
except ApiException as e:
if e.status != 404:
print("Unknown error: %s" % e)
exit(1)
if not resp:
print("Pod %s does not exits. Creating it..." % name)
pod_manifest = {
'apiVersion': 'v1',
'kind': 'Pod',
'metadata': {
'name': name
},
'spec': {
'containers': [{
'image': 'busybox',
'name': 'sleep',
"args": [
"/bin/sh",
"-c",
"while true;do date;sleep 5; done"
]
}]
}
}
resp = api.create_namespaced_pod(body=pod_manifest,
namespace='default')
while True:
resp = api.read_namespaced_pod(name=name,
namespace='default')
if resp.status.phase != 'Pending':
break
time.sleep(1)
print("Done.")
# calling exec and wait for response.
exec_command = [
'/bin/sh',
'-c',
'echo This message goes to stderr >&2; echo This message goes to stdout']
resp = api.connect_get_namespaced_pod_exec(name, 'default',
command=exec_command,
stderr=True, stdin=False,
stdout=True, tty=False)
print("Response: " + resp)
# Calling exec interactively.
exec_command = ['/bin/sh']
resp = api.connect_get_namespaced_pod_exec(name, 'default',
command=exec_command,
stderr=True, stdin=True,
stdout=True, tty=False,
_preload_content=False)
commands = [
"echo test1",
"echo \"This message goes to stderr\" >&2",
]
while resp.is_open():
resp.update(timeout=1)
if resp.peek_stdout():
print("STDOUT: %s" % resp.read_stdout())
if resp.peek_stderr():
print("STDERR: %s" % resp.read_stderr())
if commands:
c = commands.pop(0)
print("Running command... %s\n" % c)
resp.write_stdin(c + "\n")
else:
break
resp.write_stdin("date\n")
sdate = resp.readline_stdout(timeout=3)
print("Server date command returns: %s" % sdate)
resp.write_stdin("whoami\n")
user = resp.readline_stdout(timeout=3)
print("Server user is: %s" % user)

View File

@ -347,12 +347,12 @@ class ApiClient(object):
# FIXME(dims) : We need a better way to figure out which
# calls end up using web sockets
if url.endswith('/exec') and (method == "GET" or method == "POST"):
return ws_client.GET(self.config,
url,
query_params=query_params,
_request_timeout=_request_timeout,
headers=headers)
return ws_client.websocket_call(self.config,
url,
query_params=query_params,
_request_timeout=_request_timeout,
_preload_content=_preload_content,
headers=headers)
if method == "GET":
return self.rest_client.GET(url,
query_params=query_params,

View File

@ -12,33 +12,40 @@
from .rest import ApiException
import select
import certifi
import time
import collections
import websocket
from websocket import WebSocket, ABNF, enableTrace
import six
import ssl
from six.moves.urllib.parse import urlencode
from six.moves.urllib.parse import quote_plus
STDIN_CHANNEL = 0
STDOUT_CHANNEL = 1
STDERR_CHANNEL = 2
class WSClient:
def __init__(self, configuration, url, headers):
self.messages = []
self.errors = []
websocket.enableTrace(False)
header = None
"""A websocket client with support for channels.
Exec command uses different channels for different streams. for
example, 0 is stdin, 1 is stdout and 2 is stderr. Some other API calls
like port forwarding can forward different pods' streams to different
channels.
"""
enableTrace(False)
header = []
self._connected = False
self._channels = {}
self._all = ""
# We just need to pass the Authorization, ignore all the other
# http headers we get from the generated code
if 'Authorization' in headers:
header = "Authorization: %s" % headers['Authorization']
self.ws = websocket.WebSocketApp(url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
header=[header] if header else None)
self.ws.on_open = self.on_open
if headers and 'authorization' in headers:
header.append("authorization: %s" % headers['authorization'])
if url.startswith('wss://') and configuration.verify_ssl:
ssl_opts = {
@ -52,30 +59,145 @@ class WSClient:
else:
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
self.ws.run_forever(sslopt=ssl_opts)
self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
self.sock.connect(url, header=header)
self._connected = True
def on_message(self, ws, message):
if message[0] == '\x01':
message = message[1:]
if message:
if six.PY3 and isinstance(message, six.binary_type):
message = message.decode('utf-8')
self.messages.append(message)
def peek_channel(self, channel, timeout=0):
"""Peek a channel and return part of the input,
empty string otherwise."""
self.update(timeout=timeout)
if channel in self._channels:
return self._channels[channel]
return ""
def on_error(self, ws, error):
self.errors.append(error)
def read_channel(self, channel, timeout=0):
"""Read data from a channel."""
if channel not in self._channels:
ret = self.peek_channel(channel, timeout)
else:
ret = self._channels[channel]
if channel in self._channels:
del self._channels[channel]
return ret
def on_close(self, ws):
pass
def readline_channel(self, channel, timeout=None):
"""Read a line from a channel."""
if timeout is None:
timeout = float("inf")
start = time.time()
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
ret = data[:index]
data = data[index+1:]
if data:
self._channels[channel] = data
else:
del self._channels[channel]
return ret
self.update(timeout=(timeout - time.time() + start))
def on_open(self, ws):
pass
def write_channel(self, channel, data):
"""Write data to a channel."""
self.sock.send(chr(channel) + data)
def peek_stdout(self, timeout=0):
"""Same as peek_channel with channel=1."""
return self.peek_channel(STDOUT_CHANNEL, timeout=timeout)
def read_stdout(self, timeout=None):
"""Same as read_channel with channel=1."""
return self.read_channel(STDOUT_CHANNEL, timeout=timeout)
def readline_stdout(self, timeout=None):
"""Same as readline_channel with channel=1."""
return self.readline_channel(STDOUT_CHANNEL, timeout=timeout)
def peek_stderr(self, timeout=0):
"""Same as peek_channel with channel=2."""
return self.peek_channel(STDERR_CHANNEL, timeout=timeout)
def read_stderr(self, timeout=None):
"""Same as read_channel with channel=2."""
return self.read_channel(STDERR_CHANNEL, timeout=timeout)
def readline_stderr(self, timeout=None):
"""Same as readline_channel with channel=2."""
return self.readline_channel(STDERR_CHANNEL, timeout=timeout)
def read_all(self):
"""Read all of the inputs with the same order they recieved. The channel
information would be part of the string. This is useful for
non-interactive call where a set of command passed to the API call and
their result is needed after the call is concluded.
TODO: Maybe we can process this and return a more meaningful map with
channels mapped for each input.
"""
out = self._all
self._all = ""
self._channels = {}
return out
def is_open(self):
"""True if the connection is still alive."""
return self._connected
def write_stdin(self, data):
"""The same as write_channel with channel=0."""
self.write_channel(STDIN_CHANNEL, data)
def update(self, timeout=0):
"""Update channel buffers with at most one complete frame of input."""
if not self.is_open():
return
if not self.sock.connected:
self._connected = False
return
r, _, _ = select.select(
(self.sock.sock, ), (), (), timeout)
if r:
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
self._connected = False
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
data = data.decode("utf-8")
self._all += data
if len(data) > 1:
channel = ord(data[0])
data = data[1:]
if data:
if channel not in self._channels:
self._channels[channel] = data
else:
self._channels[channel] += data
def run_forever(self, timeout=None):
"""Wait till connection is closed or timeout reached. Buffer any input
received during this time."""
if timeout:
start = time.time()
while self.is_open() and time.time() - start < timeout:
self.update(timeout=(timeout - time.time() + start))
else:
while self.is_open():
self.update(timeout=None)
WSResponse = collections.namedtuple('WSResponse', ['data'])
def GET(configuration, url, query_params, _request_timeout, headers):
def websocket_call(configuration, url, query_params, _request_timeout,
_preload_content, headers):
"""An internal function to be called in api-client when a websocket
connection is required."""
# switch protocols from http to websocket
url = url.replace('http://', 'ws://')
url = url.replace('https://', 'wss://')
@ -105,10 +227,11 @@ def GET(configuration, url, query_params, _request_timeout, headers):
else:
url += '&command=' + quote_plus(commands)
client = WSClient(configuration, url, headers)
if client.errors:
raise ApiException(
status=0,
reason='\n'.join([str(error) for error in client.errors])
)
return WSResponse('%s' % ''.join(client.messages))
try:
client = WSClient(configuration, url, headers)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

View File

@ -42,4 +42,5 @@ def get_e2e_configuration():
if config.host is None:
raise unittest.SkipTest('Unable to find a running Kubernetes instance')
print('Running test against : %s' % config.host)
config.assert_hostname = False
return config

View File

@ -18,10 +18,14 @@ import uuid
from kubernetes.client import api_client
from kubernetes.client.apis import core_v1_api
from kubernetes.client.configuration import configuration
from kubernetes.e2e_test import base
def short_uuid():
id = str(uuid.uuid4())
return id[-12:]
class TestClient(unittest.TestCase):
@classmethod
@ -32,7 +36,7 @@ class TestClient(unittest.TestCase):
client = api_client.ApiClient(config=self.config)
api = core_v1_api.CoreV1Api(client)
name = 'busybox-test-' + str(uuid.uuid4())
name = 'busybox-test-' + short_uuid()
pod_manifest = {
'apiVersion': 'v1',
'kind': 'Pod',
@ -68,7 +72,7 @@ class TestClient(unittest.TestCase):
exec_command = ['/bin/sh',
'-c',
'for i in $(seq 1 3); do date; sleep 1; done']
'for i in $(seq 1 3); do date; done']
resp = api.connect_get_namespaced_pod_exec(name, 'default',
command=exec_command,
stderr=False, stdin=False,
@ -78,12 +82,29 @@ class TestClient(unittest.TestCase):
exec_command = 'uptime'
resp = api.connect_post_namespaced_pod_exec(name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
self.assertEqual(1, len(resp.splitlines()))
resp = api.connect_post_namespaced_pod_exec(name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
_preload_content=False)
resp.write_stdin("echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual("test string 1", line)
resp.write_stdin("echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual("test string 2", line)
resp.write_stdin("exit\n")
resp.update(timeout=5)
self.assertFalse(resp.is_open())
number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)
@ -94,7 +115,7 @@ class TestClient(unittest.TestCase):
client = api_client.ApiClient(config=self.config)
api = core_v1_api.CoreV1Api(client)
name = 'frontend-' + str(uuid.uuid4())
name = 'frontend-' + short_uuid()
service_manifest = {'apiVersion': 'v1',
'kind': 'Service',
'metadata': {'labels': {'name': name},
@ -133,7 +154,7 @@ class TestClient(unittest.TestCase):
client = api_client.ApiClient(config=self.config)
api = core_v1_api.CoreV1Api(client)
name = 'frontend-' + str(uuid.uuid4())
name = 'frontend-' + short_uuid()
rc_manifest = {
'apiVersion': 'v1',
'kind': 'ReplicationController',
@ -166,7 +187,7 @@ class TestClient(unittest.TestCase):
client = api_client.ApiClient(config=self.config)
api = core_v1_api.CoreV1Api(client)
name = 'test-configmap-' + str(uuid.uuid4())
name = 'test-configmap-' + short_uuid()
test_configmap = {
"kind": "ConfigMap",
"apiVersion": "v1",
@ -195,7 +216,7 @@ class TestClient(unittest.TestCase):
resp = api.delete_namespaced_config_map(
name=name, body={}, namespace='default')
resp = api.list_namespaced_config_map('kube-system', pretty=True)
resp = api.list_namespaced_config_map('default', pretty=True)
self.assertEqual([], resp.items)
def test_node_apis(self):