diff --git a/stream/stream.py b/stream/stream.py index 627fd1a33..9bb590172 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -12,19 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import types +import functools from . import ws_client -def stream(func, *args, **kwargs): - """Stream given API call using websocket. - Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()""" - - api_client = func.__self__.api_client +def _websocket_reqeust(websocket_request, api_method, *args, **kwargs): + """Override the ApiClient.request method with an alternative websocket based + method and call the supplied Kubernetes API method with that in place.""" + api_client = api_method.__self__.api_client + # old generated code's api client has config. new ones has configuration + try: + configuration = api_client.configuration + except AttributeError: + configuration = api_client.config prev_request = api_client.request try: - api_client.request = types.MethodType(ws_client.websocket_call, api_client) - return func(*args, **kwargs) + api_client.request = functools.partial(websocket_request, configuration) + return api_method(*args, **kwargs) finally: api_client.request = prev_request + + +stream = functools.partial(_websocket_reqeust, ws_client.websocket_call) diff --git a/stream/ws_client.py b/stream/ws_client.py index 313003634..fa7f393e8 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -283,18 +283,9 @@ def create_websocket(configuration, url, headers=None): return websocket -def _configuration(api_client): - # old generated code's api client has config. new ones has - # configuration - try: - return api_client.configuration - except AttributeError: - return api_client.config - - -def websocket_call(api_client, _method, url, **kwargs): +def websocket_call(configuration, _method, url, **kwargs): """An internal function to be called in api-client when a websocket - connection is required. args and kwargs are the parameters of + connection is required. method, url, and kwargs are the parameters of apiClient.request method.""" url = get_websocket_url(url, kwargs.get("query_params")) @@ -304,7 +295,7 @@ def websocket_call(api_client, _method, url, **kwargs): capture_all = kwargs.get("capture_all", True) try: - client = WSClient(_configuration(api_client), url, headers, capture_all) + client = WSClient(configuration, url, headers, capture_all) if not _preload_content: return client client.run_forever(timeout=_request_timeout)