You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
259 lines
8.9 KiB
Python
259 lines
8.9 KiB
Python
import os
|
|
from pathlib import Path
|
|
import sqlite3
|
|
import mosq_test
|
|
|
|
# Values for the client_msgs.direction column: check_counts counts rows with
# direction=0 as incoming (client_msgs_in) and direction=1 as outgoing
# (client_msgs_out).
dir_in, dir_out = 0, 1

# Client message state codes, numbered consecutively from 0. These presumably
# mirror the broker's message-state enum (compare with client_msgs.state) —
# confirm against the broker sources if the numbering ever changes.
(
    ms_invalid,
    ms_publish_qos0,
    ms_publish_qos1,
    ms_wait_for_puback,
    ms_publish_qos2,
    ms_wait_for_pubrec,
    ms_resend_pubrel,
    ms_wait_for_pubrel,
    ms_resend_pubcomp,
    ms_wait_for_pubcomp,
    ms_send_pubrec,
    ms_queued,
) = range(12)
|
|
|
|
def write_config(filename, port):
    """Write a minimal broker config file that loads the sqlite persistence plugin.

    The generated config listens on *port*, allows anonymous clients, loads
    ``mosquitto_persist_sqlite.so`` from the build tree reported by
    ``mosq_test.get_build_root()`` and points the plugin's database at the
    relative path ``<port>/mosquitto.sqlite3`` (the directory init() creates).
    An existing *filename* is overwritten.
    """
    with open(filename, 'w') as conf:
        for line in ("listener %d" % (port), "allow_anonymous true"):
            conf.write(line + "\n")
        build_root = mosq_test.get_build_root()
        conf.write(f"plugin {build_root}/plugins/persist-sqlite/mosquitto_persist_sqlite.so\n")
        conf.write("plugin_opt_db_file %d/mosquitto.sqlite3\n" % (port))
|
|
|
|
|
|
def init(port):
    """Create the working directory named after *port*.

    An already existing entry with that name is tolerated silently, so the
    function can be called repeatedly.
    """
    workdir = Path(str(port))
    try:
        workdir.mkdir()
    except FileExistsError:
        # Left over from an earlier run (or created by another test); reuse it.
        pass
|
|
|
|
|
|
def cleanup(port):
    """Remove the sqlite database and the per-port directory created by init().

    Deletes ``<port>/mosquitto.sqlite3`` and then the ``<port>`` directory.
    If the directory cannot be removed (something was left behind after the
    broker shut down), an error is printed, any ``-shm``/``-wal`` sidecar
    files are deleted and removal of the directory is retried.

    Returns:
        0 if the directory was empty once the database was removed, or if the
        leftover WAL file was empty; 1 otherwise, including when the WAL file
        is missing or the directory still cannot be removed.
    """
    db_dir = f"{port}"
    rc = 1

    try:
        os.remove(f"{db_dir}/mosquitto.sqlite3")
    except FileNotFoundError:
        pass

    try:
        os.rmdir(db_dir)
        return 0
    except OSError:
        print("ERROR sqlite3 file not removed after shutdown")

    # Some versions of sqlite3 do not remove the wal file, so a leftover WAL is
    # tolerated as long as it is empty (no pending db transactions). A missing
    # WAL means something else was left behind: report failure, don't crash.
    try:
        if Path(db_dir, "mosquitto.sqlite3-wal").stat().st_size == 0:
            rc = 0
    except FileNotFoundError:
        pass

    for suffix in ("-shm", "-wal"):
        try:
            os.remove(f"{db_dir}/mosquitto.sqlite3{suffix}")
        except FileNotFoundError:
            pass

    try:
        os.rmdir(db_dir)
    except OSError:
        # Unexpected files remain in the directory; report instead of raising.
        rc = 1

    return rc
|
|
|
|
|
|
def check_counts(port, clients=0, client_msgs_in=0, client_msgs_out=0, base_msgs=0, retain_msgs=0, subscriptions=0):
    """Verify the number of rows in each table of ``<port>/mosquitto.sqlite3``.

    The ``clients``, ``client_msgs`` (split by ``direction``), ``subscriptions``,
    ``base_msgs`` and ``retains`` tables are counted, in that order, and compared
    with the expected values.

    Raises:
        ValueError: for the first table whose row count differs from the
            expected value.
    """
    # (query, expected count, label used in the error message), checked in order.
    # direction 0/1 are the module-level dir_in/dir_out values.
    checks = (
        ('SELECT COUNT(*) FROM clients', clients, 'clients'),
        ('SELECT COUNT(*) FROM client_msgs WHERE direction=0', client_msgs_in, 'client_msgs_in'),
        ('SELECT COUNT(*) FROM client_msgs WHERE direction=1', client_msgs_out, 'client_msgs_out'),
        ('SELECT COUNT(*) FROM subscriptions', subscriptions, 'subscriptions'),
        ('SELECT COUNT(*) FROM base_msgs', base_msgs, 'base_msgs'),
        ('SELECT COUNT(*) FROM retains', retain_msgs, 'retain_msgs'),
    )

    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        for query, expected, label in checks:
            cur.execute(query)
            found = cur.fetchone()[0]
            if found != expected:
                raise ValueError("Found %d %s, expected %d" % (found, label, expected))
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|
|
|
|
|
|
def check_client(port, client_id, username, will_delay_time, session_expiry_time,
                 listener_port, max_packet_size, max_qos, retain_available,
                 session_expiry_interval, will_delay_interval):
    """Verify the first row of the ``clients`` table against expected values.

    ``username`` and ``listener_port`` are skipped when passed as None.
    ``will_delay_time`` and ``session_expiry_time`` are only checked for being
    zero vs. non-zero (presumably because the stored values are timestamps the
    test cannot predict exactly). An expected ``session_expiry_interval`` of
    4294967295 is compared as -1.

    Raises:
        ValueError: if the table is empty or any checked column differs.
    """
    # "Fix" the infinite session expiry interval as mangled by an int32 conversion.
    if session_expiry_interval == 4294967295:
        session_expiry_interval = -1

    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id, username, will_delay_time, session_expiry_time, '
                    'listener_port, max_packet_size, max_qos, retain_available, '
                    'session_expiry_interval, will_delay_interval '
                    'FROM clients')
        row = cur.fetchone()
        if row is None:
            raise ValueError("No client found, expected client_id %s" % (client_id,))

        if row[0] != client_id:
            raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

        if username is not None and row[1] != username:
            raise ValueError("Invalid username %s / %s" % (row[1], username))

        # Only "set vs. unset" is compared for these absolute-time columns.
        if (will_delay_time == 0) != (row[2] == 0):
            raise ValueError("Invalid will_delay_time %d / %d" % (row[2], will_delay_time))

        if (session_expiry_time == 0) != (row[3] == 0):
            raise ValueError("Invalid session_expiry_time %d / %d" % (row[3], session_expiry_time))

        if listener_port is not None and row[4] != listener_port:
            raise ValueError("Invalid listener_port %d / %d" % (row[4], listener_port))

        # Columns 5..9 are compared exactly and share one message format.
        exact = (('max_packet_size', max_packet_size),
                 ('max_qos', max_qos),
                 ('retain_available', retain_available),
                 ('session_expiry_interval', session_expiry_interval),
                 ('will_delay_interval', will_delay_interval))
        for (name, want), found in zip(exact, row[5:]):
            if found != want:
                raise ValueError("Invalid %s %d / %d" % (name, found, want))
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|
|
|
|
|
|
def check_subscription(port, client_id, topic, subscription_options, subscription_identifier):
    """Verify the first row of the ``subscriptions`` table against expected values.

    Raises:
        ValueError: if the table is empty or any column differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id, topic, subscription_options, subscription_identifier '
                    'FROM subscriptions')
        row = cur.fetchone()
        if row is None:
            raise ValueError("No subscription found, expected client_id %s" % (client_id,))

        if row[0] != client_id:
            raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

        if row[1] != topic:
            raise ValueError("Invalid topic %s / %s" % (row[1], topic))

        if row[2] != subscription_options:
            raise ValueError("Invalid subscription_options %d / %d" % (row[2], subscription_options))

        if row[3] != subscription_identifier:
            raise ValueError("Invalid subscription_identifier %d / %d" % (row[3], subscription_identifier))
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|
|
|
|
|
|
def check_client_msg(port, client_id, store_id, dup, direction, mid, qos, retain, state):
    """Verify the first row of the ``client_msgs`` table against expected values.

    ``direction`` is presumably one of the module-level dir_in/dir_out values
    and ``state`` one of the ms_* values; both are compared exactly.

    Raises:
        ValueError: if the table is empty or any column differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id,store_id,dup,direction,mid,qos,retain,state '
                    'FROM client_msgs')
        row = cur.fetchone()
        if row is None:
            raise ValueError("No client_msg found, expected client_id %s" % (client_id,))

        if row[0] != client_id:
            raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

        # The remaining integer columns share one message format, in SELECT order.
        expected = (('store_id', store_id), ('dup', dup), ('direction', direction),
                    ('mid', mid), ('qos', qos), ('retain', retain), ('state', state))
        for (name, want), found in zip(expected, row[1:]):
            if found != want:
                raise ValueError("Invalid %s %d / %d" % (name, found, want))
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|
|
|
|
|
|
def check_base_msg(port, expiry_time, topic, payload, source_id, source_username,
                   payloadlen, source_mid, source_port, qos, retain, idx=0):
    """Verify one row of the ``base_msgs`` table and return its store_id.

    Rows are read in the order the (unordered) SELECT yields them; the row at
    0-based position *idx* is checked. ``expiry_time`` is only checked for being
    zero vs. non-zero. ``payloadlen`` must equal the expected value and, when
    non-zero, also the length of the stored payload.

    Returns:
        The stored ``store_id`` of the checked row (never 0).

    Raises:
        ValueError: if there is no row at *idx*, the store_id is 0, or any
            checked column differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT store_id,expiry_time,topic,payload,source_id,source_username, '
                    'payloadlen, source_mid, source_port, qos, retain '
                    'FROM base_msgs')

        # Advance to the idx-th row. fetchone() keeps returning None once the
        # result set is exhausted, so a single check after the loop suffices
        # (it also covers a negative idx, where the loop does not run).
        row = None
        for _ in range(idx + 1):
            row = cur.fetchone()
        if row is None:
            raise ValueError("No base_msg found at idx %d" % idx)

        # A store_id of 0 is rejected; any non-zero id is accepted.
        if row[0] == 0:
            raise ValueError("Invalid store_id %d / expected non-zero" % row[0])

        # Only "set vs. unset" is compared for the expiry time.
        if (expiry_time == 0) != (row[1] == 0):
            raise ValueError("Invalid expiry_time %d / %d" % (row[1], expiry_time))

        if row[2] != topic:
            raise ValueError("Invalid topic %s / %s" % (row[2], topic))

        if row[3] != payload:
            raise ValueError("Invalid payload %s / %s" % (row[3], payload))

        if row[4] != source_id:
            raise ValueError("Invalid source_id %s / %s" % (row[4], source_id))

        if row[5] != source_username:
            raise ValueError("Invalid source_username %s / %s" % (row[5], source_username))

        # A NULL payload counts as length 0 instead of crashing in len().
        stored_len = len(row[3]) if row[3] is not None else 0
        if row[6] != payloadlen or (payloadlen != 0 and row[6] != stored_len):
            raise ValueError("Invalid payloadlen %d / %d" % (row[6], payloadlen))

        if row[7] != source_mid:
            raise ValueError("Invalid source_mid %d / %d" % (row[7], source_mid))

        if row[8] != source_port:
            raise ValueError("Invalid source_port %d / %d" % (row[8], source_port))

        if row[9] != qos:
            raise ValueError("Invalid qos %d / %d" % (row[9], qos))

        if row[10] != retain:
            raise ValueError("Invalid retain %d / %d" % (row[10], retain))

        return row[0]
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|
|
|
|
|
|
def check_retain(port, topic, store_id):
    """Verify that the ``retains`` row for *topic* references *store_id*.

    Raises:
        ValueError: if ``retains`` has no row for *topic* or its store_id differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT store_id FROM retains WHERE topic=?', (topic,))
        row = cur.fetchone()
        if row is None:
            raise ValueError("No retained message for topic %s" % (topic,))

        if row[0] != store_id:
            raise ValueError("Invalid store_id %d / %d" % (row[0], store_id))
    finally:
        # Close even when a check fails so the database is not left open.
        con.close()
|