Some support for running broker sequence tests over websockets

pull/2768/head
Roger A. Light 3 years ago
parent 946261571f
commit 71e87c7561

@ -25,12 +25,14 @@ class SingleMsg(object):
self.comment = comment
class MsgSequence(object):
__slots__ = 'name', 'msgs', 'msgs_all', 'expect_disconnect'
__slots__ = 'name', 'msgs', 'msgs_all', 'expect_disconnect', 'port', 'protocol'
def __init__(self, name, default_connect=True, proto_ver=4, expect_disconnect=True):
def __init__(self, name, default_connect=True, port=1888, protocol='mqtt', proto_ver=4, expect_disconnect=True):
self.name = name
self.msgs_all = deque()
self.expect_disconnect = expect_disconnect
self.port = port
self.protocol = protocol
if default_connect:
self.add_default_connect(proto_ver=proto_ver)
@ -67,7 +69,7 @@ class MsgSequence(object):
sock.send(msg.message)
def _publish_message(self, msg):
sock = mosq_test.client_connect_only(hostname="localhost", port=1888, timeout=2)
sock = mosq_test.client_connect_only(hostname="localhost", port=self.port, timeout=2, protocol=self.protocol)
sock.send(mosq_test.gen_connect("helper"))
mosq_test.expect_packet(sock, "connack", mosq_test.gen_connack(rc=0))
@ -95,7 +97,7 @@ class MsgSequence(object):
data = sock.recv(1)
if len(data) == 1 and self.expect_disconnect:
raise ValueError("Still connected")
except ConnectionResetError:
except (ConnectionResetError, BlockingIOError):
if self.expect_disconnect:
pass
else:
@ -127,7 +129,7 @@ class MsgSequence(object):
self._connected_check(sock)
def do_test(hostname, port):
def do_test(hostname, port, protocol):
data_path=Path(__file__).resolve().parent/"data"
rc = 0
sequences = []
@ -184,6 +186,8 @@ def do_test(hostname, port):
expect_disconnect = g_expect_disconnect
this_test = MsgSequence(tname,
port=port,
protocol=protocol,
proto_ver=proto_ver,
expect_disconnect=expect_disconnect,
default_connect=connect)
@ -203,7 +207,7 @@ def do_test(hostname, port):
total += 1
try:
failed_tests.append(this_test)
sock = mosq_test.client_connect_only(hostname=hostname, port=port, timeout=2)
sock = mosq_test.client_connect_only(hostname=hostname, port=port, timeout=2, protocol=protocol)
this_test.process_all(sock)
print("\033[32m" + tname + "\033[0m")
succeeded += 1
@ -221,6 +225,7 @@ def do_test(hostname, port):
print("\033[31m" + tname + " failed: " + str(e) + "\033[0m")
rc = 1
sock.close()
exit()
except mosq_test.TestError as e:
print("\033[31m" + tname + " failed: " + str(e) + "\033[0m")
rc = 1
@ -230,8 +235,9 @@ def do_test(hostname, port):
if False:
for t in failed_tests:
try:
sock = mosq_test.client_connect_only(hostname=hostname, port=port, timeout=2)
sock = mosq_test.client_connect_only(hostname=hostname, port=port, timeout=2, protocol=protocol)
t.process_all(sock)
length = len(data)
print("\033[32m" + t.name + "\033[0m")
sock.close()
except ValueError as e:
@ -254,19 +260,35 @@ def do_test(hostname, port):
print("%d tests total\n%d tests succeeded" % (total, succeeded))
return rc
hostname = "localhost"
port = mosq_test.get_port()
broker = mosq_test.start_broker(filename=os.path.basename(__file__), port=port, nolog=True)
rc = 0
try:
rc = do_test(hostname=hostname, port=port)
finally:
broker.terminate()
if mosq_test.wait_for_subprocess(broker):
print("broker not terminated")
if rc == 0: rc=1
(stdo, stde) = broker.communicate()
if rc:
#print(stde.decode('utf-8'))
exit(rc)
def write_config(filename, port, protocol):
with open(filename, 'w') as f:
f.write(f'listener {port}\n')
f.write(f'protocol {protocol}\n')
f.write("allow_anonymous true\n")
f.write("log_type all\n")
def main(protocol):
hostname = "localhost"
port = mosq_test.get_port()
conf_file = 'msg_sequence_test.conf'
write_config(conf_file, port, protocol)
broker = mosq_test.start_broker(filename=conf_file, port=port, use_conf=True, nolog=True)
rc = 0
try:
rc = do_test(hostname=hostname, port=port, protocol=protocol)
finally:
broker.terminate()
os.remove(conf_file)
if mosq_test.wait_for_subprocess(broker):
print("broker not terminated")
if rc == 0: rc=1
(stdo, stde) = broker.communicate()
if rc:
#print(stde.decode('utf-8'))
exit(rc)
#main(protocol="websockets")
main(protocol="mqtt")

@ -1,11 +1,14 @@
import atexit
import base64
import errno
import hashlib
import os
import socket
import subprocess
import struct
import sys
import time
import uuid
import traceback
@ -64,7 +67,7 @@ def start_broker(filename, cmd=None, port=0, use_conf=False, expect_fail=False,
elif os.environ.get('MOSQ_USE_VALGRIND') == 'failgrind':
cmd = ['fg-helper'] + cmd
else:
cmd = ['valgrind', '-q', '--trace-children=yes', '--leak-check=full', '--show-leak-kinds=all', '--log-file='+logfile] + cmd
cmd = ['valgrind', '-q', '--track-fds=yes', '--trace-children=yes', '--leak-check=full', '--show-leak-kinds=all', '--log-file='+logfile] + cmd
vg_logfiles.append(logfile)
vg_index += 1
delay = 1
@ -249,10 +252,17 @@ def do_receive_send(sock, receive_packet, send_packet, error_string="receive sen
raise ValueError
def client_connect_only(hostname="localhost", port=1888, timeout=10):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
def client_connect_only(hostname="localhost", port=1888, timeout=10, protocol="mqtt"):
if protocol == "websockets":
addr = (hostname, port)
sock = socket.create_connection(addr, timeout=timeout)
sock.settimeout(timeout)
sock = WebsocketWrapper(sock, hostname, port, False, "/mqtt", None)
#sock.setblocking(0)
else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
return sock
def client_connect_only_unix(path, timeout=10):
@ -261,8 +271,8 @@ def client_connect_only_unix(path, timeout=10):
sock.connect(path)
return sock
def do_client_connect(connect_packet, connack_packet, hostname="localhost", port=1888, timeout=10, connack_error="connack"):
sock = client_connect_only(hostname, port, timeout)
def do_client_connect(connect_packet, connack_packet, hostname="localhost", port=1888, timeout=10, connack_error="connack", protocol="mqtt"):
sock = client_connect_only(hostname, port, timeout, protocol)
return do_send_receive(sock, connect_packet, connack_packet, connack_error)
@ -884,6 +894,318 @@ def client_test(client_cmd, client_args, callback, cb_data):
exit(rc)
# =============================================
# Websockets wrapper
# =============================================
class WebsocketConnectionError(ValueError):
pass
class WebsocketWrapper(object):
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CONNCLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
def __init__(self, socket, host, port, is_ssl, path, extra_headers):
self.connected = False
self._ssl = is_ssl
self._host = host
self._port = port
self._socket = socket
self._path = path
self._sendbuffer = bytearray()
self._readbuffer = bytearray()
self._requested_size = 0
self._payload_head = 0
self._readbuffer_head = 0
self._do_handshake(extra_headers)
def __del__(self):
self._sendbuffer = None
self._readbuffer = None
def _do_handshake(self, extra_headers):
sec_websocket_key = uuid.uuid4().bytes
sec_websocket_key = base64.b64encode(sec_websocket_key)
websocket_headers = {
"Host": "{self._host:s}:{self._port:d}".format(self=self),
"Upgrade": "websocket",
"Connection": "Upgrade",
"Origin": "https://{self._host:s}:{self._port:d}".format(self=self),
"Sec-WebSocket-Key": sec_websocket_key.decode("utf8"),
"Sec-Websocket-Version": "13",
"Sec-Websocket-Protocol": "mqtt",
}
# This is checked in ws_set_options so it will either be None, a
# dictionary, or a callable
if isinstance(extra_headers, dict):
websocket_headers.update(extra_headers)
elif callable(extra_headers):
websocket_headers = extra_headers(websocket_headers)
header = "\r\n".join([
"GET {self._path} HTTP/1.1".format(self=self),
"\r\n".join("{}: {}".format(i, j)
for i, j in websocket_headers.items()),
"\r\n",
]).encode("utf8")
self._socket.send(header)
has_secret = False
has_upgrade = False
while True:
# read HTTP response header as lines
byte = self._socket.recv(1)
self._readbuffer.extend(byte)
# line end
if byte == b"\n":
if len(self._readbuffer) > 2:
# check upgrade
if b"connection" in str(self._readbuffer).lower().encode('utf-8'):
if b"upgrade" not in str(self._readbuffer).lower().encode('utf-8'):
raise WebsocketConnectionError(
"WebSocket handshake error, connection not upgraded")
else:
has_upgrade = True
# check key hash
if b"sec-websocket-accept" in str(self._readbuffer).lower().encode('utf-8'):
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
server_hash = self._readbuffer.decode(
'utf-8').split(": ", 1)[1]
server_hash = server_hash.strip().encode('utf-8')
client_hash = sec_websocket_key.decode('utf-8') + GUID
client_hash = hashlib.sha1(client_hash.encode('utf-8'))
client_hash = base64.b64encode(client_hash.digest())
if server_hash != client_hash:
raise WebsocketConnectionError(
"WebSocket handshake error, invalid secret key")
else:
has_secret = True
else:
# ending linebreak
break
# reset linebuffer
self._readbuffer = bytearray()
# connection reset
elif not byte:
raise WebsocketConnectionError("WebSocket handshake error")
if not has_upgrade or not has_secret:
raise WebsocketConnectionError("WebSocket handshake error")
self._readbuffer = bytearray()
self.connected = True
def _create_frame(self, opcode, data, do_masking=1):
header = bytearray()
length = len(data)
mask_key = bytearray(os.urandom(4))
mask_flag = do_masking
# 1 << 7 is the final flag, we don't send continuated data
header.append(1 << 7 | opcode)
if length < 126:
header.append(mask_flag << 7 | length)
elif length < 65536:
header.append(mask_flag << 7 | 126)
header += struct.pack("!H", length)
elif length < 0x8000000000000001:
header.append(mask_flag << 7 | 127)
header += struct.pack("!Q", length)
else:
raise ValueError("Maximum payload size is 2^63")
if mask_flag == 1:
for index in range(length):
data[index] ^= mask_key[index % 4]
data = mask_key + data
return header + data
def _buffered_read(self, length):
# try to recv and store needed bytes
wanted_bytes = length - (len(self._readbuffer) - self._readbuffer_head)
if wanted_bytes > 0:
data = self._socket.recv(wanted_bytes)
if not data:
raise ConnectionAbortedError
else:
self._readbuffer.extend(data)
if len(data) < wanted_bytes:
print(f"{len(data)} {wanted_bytes}")
raise BlockingIOError
self._readbuffer_head += length
return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head]
def _recv_impl(self, length):
# try to decode websocket payload part from data
try:
self._readbuffer_head = 0
result = None
chunk_startindex = self._payload_head
chunk_endindex = self._payload_head + length
header1 = self._buffered_read(1)
header2 = self._buffered_read(1)
opcode = (header1[0] & 0x0f)
maskbit = (header2[0] & 0x80) == 0x80
lengthbits = (header2[0] & 0x7f)
payload_length = lengthbits
mask_key = None
# read length
if lengthbits == 0x7e:
value = self._buffered_read(2)
payload_length, = struct.unpack("!H", value)
elif lengthbits == 0x7f:
value = self._buffered_read(8)
payload_length, = struct.unpack("!Q", value)
# read mask
if maskbit:
mask_key = self._buffered_read(4)
# if frame payload is shorter than the requested data, read only the possible part
readindex = chunk_endindex
if payload_length < readindex:
readindex = payload_length
if readindex > 0:
# get payload chunk
payload = self._buffered_read(readindex)
# unmask only the needed part
if maskbit:
for index in range(chunk_startindex, readindex):
payload[index] ^= mask_key[index % 4]
result = payload[chunk_startindex:readindex]
self._payload_head = readindex
else:
payload = bytearray()
# check if full frame arrived and reset readbuffer and payloadhead if needed
if readindex == payload_length:
self._readbuffer = bytearray()
self._payload_head = 0
# respond to non-binary opcodes, their arrival is not guaranteed beacause of non-blocking sockets
if opcode == WebsocketWrapper.OPCODE_CONNCLOSE:
frame = self._create_frame(
WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0)
self._socket.send(frame)
if opcode == WebsocketWrapper.OPCODE_PING:
frame = self._create_frame(
WebsocketWrapper.OPCODE_PONG, payload, 0)
self._socket.send(frame)
# This isn't *proper* handling of continuation frames, but given
# that we only support binary frames, it is *probably* good enough.
if (opcode == WebsocketWrapper.OPCODE_BINARY or opcode == WebsocketWrapper.OPCODE_CONTINUATION) \
and payload_length > 0:
return result
else:
#raise BlockingIOError
return b""
except ConnectionError:
self.connected = False
return b''
def _send_impl(self, data):
# if previous frame was sent successfully
if len(self._sendbuffer) == 0:
# create websocket frame
frame = self._create_frame(
WebsocketWrapper.OPCODE_BINARY, bytearray(data))
self._sendbuffer.extend(frame)
self._requested_size = len(data)
# try to write out as much as possible
length = self._socket.send(self._sendbuffer)
self._sendbuffer = self._sendbuffer[length:]
if len(self._sendbuffer) == 0:
# buffer sent out completely, return with payload's size
return self._requested_size
else:
# couldn't send whole data, request the same data again with 0 as sent length
return 0
def recv(self, length):
return self._recv_impl(length)
def read(self, length):
return self._recv_impl(length)
def send(self, data):
return self._send_impl(data)
def write(self, data):
return self._send_impl(data)
def close(self):
self._socket.close()
def fileno(self):
return self._socket.fileno()
def pending(self):
# Fix for bug #131: a SSL socket may still have data available
# for reading without select() being aware of it.
if self._ssl:
return self._socket.pending()
else:
# normal socket rely only on select()
return 0
def setblocking(self, flag):
self._socket.setblocking(flag)
@atexit.register
def test_cleanup():
global vg_logfiles

Loading…
Cancel
Save