Use urlparse to generate websocket url

Fixes #246
This commit is contained in:
Sergi Almacellas Abellana 2017-06-07 11:05:11 +02:00
parent e245dbbd32
commit 9cd0b3dc9a
2 changed files with 49 additions and 10 deletions

View File

@ -19,8 +19,7 @@ import collections
from websocket import WebSocket, ABNF, enableTrace
import six
import ssl
from six.moves.urllib.parse import urlencode
from six.moves.urllib.parse import quote_plus
from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
STDIN_CHANNEL = 0
STDOUT_CHANNEL = 1
@ -203,18 +202,21 @@ class WSClient:
WSResponse = collections.namedtuple('WSResponse', ['data'])
def get_websocket_url(url):
parsed_url = urlparse(url)
parts = list(parsed_url)
if parsed_url.scheme == 'http':
parts[0] = 'ws'
elif parsed_url.scheme == 'https':
parts[0] = 'wss'
return urlunparse(parts)
def websocket_call(configuration, url, query_params, _request_timeout,
_preload_content, headers):
"""An internal function to be called in api-client when a websocket
connection is required."""
# switch protocols from http to websocket
url = url.replace('http://', 'ws://')
url = url.replace('https://', 'wss://')
# patch extra /
url = url.replace('//api', '/api')
# Extract the command from the list of tuples
commands = None
for key, value in query_params:
@ -238,7 +240,7 @@ def websocket_call(configuration, url, query_params, _request_timeout,
url += '&command=' + quote_plus(commands)
try:
client = WSClient(configuration, url, headers)
client = WSClient(configuration, get_websocket_url(url), headers)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)

View File

@ -0,0 +1,37 @@
# Copyright 2017 The Kubernetes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from .ws_client import get_websocket_url
class WSClientTest(unittest.TestCase):
def test_websocket_client(self):
for url, ws_url in [
('http://localhost/api', 'ws://localhost/api'),
('https://localhost/api', 'wss://localhost/api'),
('https://domain.com/api', 'wss://domain.com/api'),
('https://api.domain.com/api', 'wss://api.domain.com/api'),
('http://api.domain.com', 'ws://api.domain.com'),
('https://api.domain.com', 'wss://api.domain.com'),
('http://api.domain.com/', 'ws://api.domain.com/'),
('https://api.domain.com/', 'wss://api.domain.com/'),
]:
self.assertEqual(get_websocket_url(url), ws_url)
if __name__ == '__main__':
unittest.main()