Merge pull request #210 from iciclespider/port-forward
Implement port forwarding.
This commit is contained in:
commit
3dc7fe0b92
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .stream import stream
|
||||
from .stream import stream, portforward
|
||||
|
||||
@ -17,9 +17,12 @@ import functools
|
||||
from . import ws_client
|
||||
|
||||
|
||||
def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
|
||||
def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs):
|
||||
"""Override the ApiClient.request method with an alternative websocket based
|
||||
method and call the supplied Kubernetes API method with that in place."""
|
||||
if force_kwargs:
|
||||
for kwarg, value in force_kwargs.items():
|
||||
kwargs[kwarg] = value
|
||||
api_client = api_method.__self__.api_client
|
||||
# old generated code's api client has config. new ones has configuration
|
||||
try:
|
||||
@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
|
||||
api_client.request = prev_request
|
||||
|
||||
|
||||
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call)
|
||||
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None)
|
||||
portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False})
|
||||
|
||||
@ -12,12 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from kubernetes.client.rest import ApiException
|
||||
from kubernetes.client.rest import ApiException, ApiValueError
|
||||
|
||||
import certifi
|
||||
import collections
|
||||
import select
|
||||
import socket
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
|
||||
import six
|
||||
@ -225,6 +227,174 @@ class WSClient:
|
||||
WSResponse = collections.namedtuple('WSResponse', ['data'])
|
||||
|
||||
|
||||
class PortForward:
|
||||
def __init__(self, websocket, ports):
|
||||
"""A websocket client with support for port forwarding.
|
||||
|
||||
Port Forward command sends on 2 channels per port, a read/write
|
||||
data channel and a read only error channel. Both channels are sent an
|
||||
initial frame contaning the port number that channel is associated with.
|
||||
"""
|
||||
|
||||
self.websocket = websocket
|
||||
self.local_ports = {}
|
||||
for ix, port_number in enumerate(ports):
|
||||
self.local_ports[port_number] = self._Port(ix, port_number)
|
||||
# There is a thread run per PortForward instance which performs the translation between the
|
||||
# raw socket data sent by the python application and the websocket protocol. This thread
|
||||
# terminates after either side has closed all ports, and after flushing all pending data.
|
||||
proxy = threading.Thread(
|
||||
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
|
||||
target=self._proxy
|
||||
)
|
||||
proxy.daemon = True
|
||||
proxy.start()
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.websocket.connected
|
||||
|
||||
def socket(self, port_number):
|
||||
if port_number not in self.local_ports:
|
||||
raise ValueError("Invalid port number")
|
||||
return self.local_ports[port_number].socket
|
||||
|
||||
def error(self, port_number):
|
||||
if port_number not in self.local_ports:
|
||||
raise ValueError("Invalid port number")
|
||||
return self.local_ports[port_number].error
|
||||
|
||||
def close(self):
|
||||
for port in self.local_ports.values():
|
||||
port.socket.close()
|
||||
|
||||
class _Port:
|
||||
def __init__(self, ix, port_number):
|
||||
# The remote port number
|
||||
self.port_number = port_number
|
||||
# The websocket channel byte number for this port
|
||||
self.channel = six.int2byte(ix * 2)
|
||||
# A socket pair is created to provide a means of translating the data flow
|
||||
# between the python application and the kubernetes websocket. The self.python
|
||||
# half of the socket pair is used by the _proxy method to receive and send data
|
||||
# to the running python application.
|
||||
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
# The self.socket half of the pair is used by the python application to send
|
||||
# and receive data to the eventual pod port. It is wrapped in the _Socket class
|
||||
# because a socket pair is an AF_UNIX socket, not a AF_INET socket. This allows
|
||||
# intercepting setting AF_INET socket options that would error against an AF_UNIX
|
||||
# socket.
|
||||
self.socket = self._Socket(s)
|
||||
# Data accumulated from the websocket to be sent to the python application.
|
||||
self.data = b''
|
||||
# All data sent from kubernetes on the port error channel.
|
||||
self.error = None
|
||||
|
||||
class _Socket:
|
||||
def __init__(self, socket):
|
||||
self._socket = socket
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._socket, name)
|
||||
|
||||
def setsockopt(self, level, optname, value):
|
||||
# The following socket option is not valid with a socket created from socketpair,
|
||||
# and is set by the http.client.HTTPConnection.connect method.
|
||||
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
|
||||
return
|
||||
self._socket.setsockopt(level, optname, value)
|
||||
|
||||
# Proxy all socket data between the python code and the kubernetes websocket.
|
||||
def _proxy(self):
|
||||
channel_ports = []
|
||||
channel_initialized = []
|
||||
local_ports = {}
|
||||
for port in self.local_ports.values():
|
||||
# Setup the data channel for this port number
|
||||
channel_ports.append(port)
|
||||
channel_initialized.append(False)
|
||||
# Setup the error channel for this port number
|
||||
channel_ports.append(port)
|
||||
channel_initialized.append(False)
|
||||
port.python.setblocking(True)
|
||||
local_ports[port.python] = port
|
||||
# The data to send on the websocket socket
|
||||
kubernetes_data = b''
|
||||
while True:
|
||||
rlist = [] # List of sockets to read from
|
||||
wlist = [] # List of sockets to write to
|
||||
if self.websocket.connected:
|
||||
rlist.append(self.websocket)
|
||||
if kubernetes_data:
|
||||
wlist.append(self.websocket)
|
||||
local_all_closed = True
|
||||
for port in self.local_ports.values():
|
||||
if port.python.fileno() != -1:
|
||||
if port.error or not self.websocket.connected:
|
||||
if port.data:
|
||||
wlist.append(port.python)
|
||||
local_all_closed = False
|
||||
else:
|
||||
port.python.close()
|
||||
else:
|
||||
rlist.append(port.python)
|
||||
if port.data:
|
||||
wlist.append(port.python)
|
||||
local_all_closed = False
|
||||
if local_all_closed and not (self.websocket.connected and kubernetes_data):
|
||||
self.websocket.close()
|
||||
return
|
||||
r, w, _ = select.select(rlist, wlist, [])
|
||||
for sock in r:
|
||||
if sock == self.websocket:
|
||||
opcode, frame = self.websocket.recv_data_frame(True)
|
||||
if opcode == ABNF.OPCODE_BINARY:
|
||||
if not frame.data:
|
||||
raise RuntimeError("Unexpected frame data size")
|
||||
channel = six.byte2int(frame.data)
|
||||
if channel >= len(channel_ports):
|
||||
raise RuntimeError("Unexpected channel number: %s" % channel)
|
||||
port = channel_ports[channel]
|
||||
if channel_initialized[channel]:
|
||||
if channel % 2:
|
||||
if port.error is None:
|
||||
port.error = ''
|
||||
port.error += frame.data[1:].decode()
|
||||
else:
|
||||
port.data += frame.data[1:]
|
||||
else:
|
||||
if len(frame.data) != 3:
|
||||
raise RuntimeError(
|
||||
"Unexpected initial channel frame data size"
|
||||
)
|
||||
port_number = six.byte2int(frame.data[1:2]) + (six.byte2int(frame.data[2:3]) * 256)
|
||||
if port_number != port.port_number:
|
||||
raise RuntimeError(
|
||||
"Unexpected port number in initial channel frame: %s" % port_number
|
||||
)
|
||||
channel_initialized[channel] = True
|
||||
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
|
||||
raise RuntimeError("Unexpected websocket opcode: %s" % opcode)
|
||||
else:
|
||||
port = local_ports[sock]
|
||||
data = port.python.recv(1024 * 1024)
|
||||
if data:
|
||||
kubernetes_data += ABNF.create_frame(
|
||||
port.channel + data,
|
||||
ABNF.OPCODE_BINARY,
|
||||
).format()
|
||||
else:
|
||||
port.python.close()
|
||||
for sock in w:
|
||||
if sock == self.websocket:
|
||||
sent = self.websocket.sock.send(kubernetes_data)
|
||||
kubernetes_data = kubernetes_data[sent:]
|
||||
else:
|
||||
port = local_ports[sock]
|
||||
sent = port.python.send(port.data)
|
||||
port.data = port.data[sent:]
|
||||
|
||||
|
||||
def get_websocket_url(url, query_params=None):
|
||||
parsed_url = urlparse(url)
|
||||
parts = list(parsed_url)
|
||||
@ -302,3 +472,36 @@ def websocket_call(configuration, _method, url, **kwargs):
|
||||
return WSResponse('%s' % ''.join(client.read_all()))
|
||||
except (Exception, KeyboardInterrupt, SystemExit) as e:
|
||||
raise ApiException(status=0, reason=str(e))
|
||||
|
||||
|
||||
def portforward_call(configuration, _method, url, **kwargs):
|
||||
"""An internal function to be called in api-client when a websocket
|
||||
connection is required for port forwarding. args and kwargs are the
|
||||
parameters of apiClient.request method."""
|
||||
|
||||
query_params = kwargs.get("query_params")
|
||||
|
||||
ports = []
|
||||
for param, value in query_params:
|
||||
if param == 'ports':
|
||||
for port in value.split(','):
|
||||
try:
|
||||
port_number = int(port)
|
||||
except ValueError:
|
||||
raise ApiValueError("Invalid port number: %s" % port)
|
||||
if not (0 < port_number < 65536):
|
||||
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
|
||||
if port_number in ports:
|
||||
raise ApiValueError("Duplicate port numbers: %s" % port)
|
||||
ports.append(port_number)
|
||||
if not ports:
|
||||
raise ApiValueError("Missing required parameter `ports`")
|
||||
|
||||
url = get_websocket_url(url, query_params)
|
||||
headers = kwargs.get("headers")
|
||||
|
||||
try:
|
||||
websocket = create_websocket(configuration, url, headers)
|
||||
return PortForward(websocket, ports)
|
||||
except (Exception, KeyboardInterrupt, SystemExit) as e:
|
||||
raise ApiException(status=0, reason=str(e))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user