You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mosquitto/test/broker/persist_sqlite.py

259 lines
8.9 KiB
Python

import os
from pathlib import Path
import sqlite3
import mosq_test
# Values of the `direction` column of the client_msgs table: check_counts()
# counts direction=0 rows as incoming and direction=1 rows as outgoing.
dir_in, dir_out = 0, 1

# Message state values, numbered consecutively from 0.
# NOTE(review): presumably these mirror the broker's message state enum and
# are passed as `state` to check_client_msg() - confirm against the broker.
(
    ms_invalid,           # 0
    ms_publish_qos0,      # 1
    ms_publish_qos1,      # 2
    ms_wait_for_puback,   # 3
    ms_publish_qos2,      # 4
    ms_wait_for_pubrec,   # 5
    ms_resend_pubrel,     # 6
    ms_wait_for_pubrel,   # 7
    ms_resend_pubcomp,    # 8
    ms_wait_for_pubcomp,  # 9
    ms_send_pubrec,       # 10
    ms_queued,            # 11
) = range(12)
def write_config(filename, port):
    """Write a broker configuration that uses the persist-sqlite plugin.

    The generated file declares a listener on `port`, allows anonymous
    clients, loads the persist-sqlite plugin from the build root reported by
    mosq_test, and points the plugin at `<port>/mosquitto.sqlite3`.

    Args:
        filename: Path of the configuration file to create or overwrite.
        port: Listener port; also used as the database directory name.
    """
    with open(filename, 'w') as conf:
        emit = conf.write
        emit("listener %d\n" % port)
        emit("allow_anonymous true\n")
        emit(f"plugin {mosq_test.get_build_root()}/plugins/persist-sqlite/mosquitto_persist_sqlite.so\n")
        emit("plugin_opt_db_file %d/mosquitto.sqlite3\n" % port)
def init(port):
    """Create the directory named after `port`, tolerating an existing one.

    Args:
        port: Value whose string form is used as the directory name.
    """
    directory = str(port)
    try:
        os.mkdir(directory)
    except FileExistsError:
        # A directory left over from an earlier run is fine.
        return
def cleanup(port):
    """Remove the database directory for `port` and report how clean it was.

    The database file is deleted first, then the directory. If the directory
    cannot be removed, the remaining `-wal` / `-shm` files are deleted and the
    removal is retried.

    Args:
        port: Value whose string form names the database directory.

    Returns:
        0 if the directory could be removed directly, or if the write-ahead
        log left behind was empty; 1 otherwise (including when the directory
        was blocked by something other than a WAL file).

    Raises:
        OSError: If the directory still cannot be removed after the leftover
            SQLite files have been deleted.
    """
    rc = 1
    try:
        os.remove(f"{port}/mosquitto.sqlite3")
    except FileNotFoundError:
        pass

    try:
        os.rmdir(f"{port}")
        rc = 0
    except OSError:
        print("ERROR sqlite3 file not removed after shutdown")
        try:
            if Path(str(port), "mosquitto.sqlite3-wal").stat().st_size == 0:
                # some versions of sqlite3 do not remove the wal file
                # thus we make sure that the file is at least empty (no pending db transactions)
                rc = 0
        except FileNotFoundError:
            # No WAL file at all: something else blocked the rmdir. Keep
            # rc == 1, but still fall through so the leftovers get removed
            # instead of aborting with an exception raised inside this handler.
            pass

        for leftover in ("mosquitto.sqlite3-shm", "mosquitto.sqlite3-wal"):
            try:
                os.remove(f"{port}/{leftover}")
            except FileNotFoundError:
                pass
        os.rmdir(f"{port}")
    return rc
def check_counts(port, clients=0, client_msgs_in=0, client_msgs_out=0, base_msgs=0, retain_msgs=0, subscriptions=0):
    """Verify the row counts of the tables in `<port>/mosquitto.sqlite3`.

    The counts are checked in this order: clients, incoming client messages
    (direction=0), outgoing client messages (direction=1), subscriptions,
    base messages, retained messages. The first mismatch is reported.

    Args:
        port: Value naming the directory that holds the database file.
        clients: Expected number of rows in `clients`.
        client_msgs_in: Expected number of `client_msgs` rows with direction=0.
        client_msgs_out: Expected number of `client_msgs` rows with direction=1.
        base_msgs: Expected number of rows in `base_msgs`.
        retain_msgs: Expected number of rows in `retains`.
        subscriptions: Expected number of rows in `subscriptions`.

    Raises:
        ValueError: If any count differs from the expected value.
    """
    # (query, label used in the error message, expected count)
    expectations = (
        ('SELECT COUNT(*) FROM clients', "clients", clients),
        ('SELECT COUNT(*) FROM client_msgs WHERE direction=0', "client_msgs_in", client_msgs_in),
        ('SELECT COUNT(*) FROM client_msgs WHERE direction=1', "client_msgs_out", client_msgs_out),
        ('SELECT COUNT(*) FROM subscriptions', "subscriptions", subscriptions),
        ('SELECT COUNT(*) FROM base_msgs', "base_msgs", base_msgs),
        ('SELECT COUNT(*) FROM retains', "retain_msgs", retain_msgs),
    )

    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        for query, label, expected in expectations:
            cur.execute(query)
            found = cur.fetchone()[0]
            if found != expected:
                raise ValueError("Found %d %s, expected %d" % (found, label, expected))
    finally:
        # Always release the database, even when a check fails, so a failing
        # test does not keep the file open while the directory is cleaned up.
        con.close()
def check_client(port, client_id, username, will_delay_time, session_expiry_time,
                 listener_port, max_packet_size, max_qos, retain_available,
                 session_expiry_interval, will_delay_interval):
    """Verify the first row of the `clients` table in `<port>/mosquitto.sqlite3`.

    Args:
        port: Value naming the directory that holds the database file.
        client_id: Expected client id.
        username: Expected username; not checked when None.
        will_delay_time: Only compared for zero versus non-zero.
        session_expiry_time: Only compared for zero versus non-zero.
        listener_port: Expected listener port; not checked when None.
        max_packet_size: Expected maximum packet size.
        max_qos: Expected maximum QoS.
        retain_available: Expected retain-available value.
        session_expiry_interval: Expected session expiry interval; 4294967295
            is compared as -1.
        will_delay_interval: Expected will delay interval.

    Raises:
        ValueError: If the table is empty or any checked column differs.
    """
    # "Fix" the infinite session expiry interval as mangled by an int32 conversion.
    if session_expiry_interval == 4294967295:
        session_expiry_interval = -1

    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id, username, will_delay_time, session_expiry_time, ' +
                    'listener_port, max_packet_size, max_qos, retain_available, ' +
                    'session_expiry_interval, will_delay_interval ' +
                    'FROM clients')
        row = cur.fetchone()
    finally:
        # Release the database before validating so a failed check cannot
        # leave the connection open.
        con.close()

    if row is None:
        raise ValueError("No client found, expected %s" % (client_id,))

    if row[0] != client_id:
        raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

    if username is not None and row[1] != username:
        raise ValueError("Invalid username %s / %s" % (row[1], username))

    if (will_delay_time == 0 and row[2] != 0) or (will_delay_time != 0 and row[2] == 0):
        raise ValueError("Invalid will_delay_time %d / %d" % (row[2], will_delay_time))

    if (session_expiry_time == 0 and row[3] != 0) or (session_expiry_time != 0 and row[3] == 0):
        raise ValueError("Invalid session_expiry_time %d / %d" % (row[3], session_expiry_time))

    if listener_port is not None and row[4] != listener_port:
        raise ValueError("Invalid listener_port %d / %d" % (row[4], listener_port))

    if row[5] != max_packet_size:
        raise ValueError("Invalid max_packet_size %d / %d" % (row[5], max_packet_size))

    if row[6] != max_qos:
        raise ValueError("Invalid max_qos %d / %d" % (row[6], max_qos))

    if row[7] != retain_available:
        raise ValueError("Invalid retain_available %d / %d" % (row[7], retain_available))

    if row[8] != session_expiry_interval:
        raise ValueError("Invalid session_expiry_interval %d / %d" % (row[8], session_expiry_interval))

    if row[9] != will_delay_interval:
        raise ValueError("Invalid will_delay_interval %d / %d" % (row[9], will_delay_interval))
def check_subscription(port, client_id, topic, subscription_options, subscription_identifier):
    """Verify the first row of the `subscriptions` table in `<port>/mosquitto.sqlite3`.

    Args:
        port: Value naming the directory that holds the database file.
        client_id: Expected client id.
        topic: Expected subscription topic.
        subscription_options: Expected subscription options value.
        subscription_identifier: Expected subscription identifier.

    Raises:
        ValueError: If the table is empty or any column differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id, topic, subscription_options, subscription_identifier ' +
                    'FROM subscriptions')
        row = cur.fetchone()
    finally:
        # Release the database before validating so a failed check cannot
        # leave the connection open.
        con.close()

    if row is None:
        raise ValueError("No subscription found, expected %s / %s" % (client_id, topic))

    if row[0] != client_id:
        raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

    if row[1] != topic:
        raise ValueError("Invalid topic %s / %s" % (row[1], topic))

    if row[2] != subscription_options:
        raise ValueError("Invalid subscription_options %d / %d" % (row[2], subscription_options))

    if row[3] != subscription_identifier:
        raise ValueError("Invalid subscription_identifier %d / %d" % (row[3], subscription_identifier))
def check_client_msg(port, client_id, store_id, dup, direction, mid, qos, retain, state):
    """Verify the first row of the `client_msgs` table in `<port>/mosquitto.sqlite3`.

    Args:
        port: Value naming the directory that holds the database file.
        client_id: Expected client id.
        store_id: Expected store id of the referenced message.
        dup: Expected dup value.
        direction: Expected direction value.
        mid: Expected message id.
        qos: Expected QoS.
        retain: Expected retain value.
        state: Expected message state value.

    Raises:
        ValueError: If the table is empty or any column differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT client_id,store_id,dup,direction,mid,qos,retain,state ' +
                    'FROM client_msgs')
        row = cur.fetchone()
    finally:
        # Release the database before validating so a failed check cannot
        # leave the connection open.
        con.close()

    if row is None:
        raise ValueError("No client message found, expected client_id %s" % (client_id,))

    if row[0] != client_id:
        raise ValueError("Invalid client_id %s / %s" % (row[0], client_id))

    if row[1] != store_id:
        raise ValueError("Invalid store_id %d / %d" % (row[1], store_id))

    if row[2] != dup:
        raise ValueError("Invalid dup %d / %d" % (row[2], dup))

    if row[3] != direction:
        raise ValueError("Invalid direction %d / %d" % (row[3], direction))

    if row[4] != mid:
        raise ValueError("Invalid mid %d / %d" % (row[4], mid))

    if row[5] != qos:
        raise ValueError("Invalid qos %d / %d" % (row[5], qos))

    if row[6] != retain:
        raise ValueError("Invalid retain %d / %d" % (row[6], retain))

    if row[7] != state:
        raise ValueError("Invalid state %d / %d" % (row[7], state))
def check_base_msg(port, expiry_time, topic, payload, source_id, source_username,
                   payloadlen, source_mid, source_port, qos, retain, idx=0):
    """Verify one row of the `base_msgs` table in `<port>/mosquitto.sqlite3`.

    Args:
        port: Value naming the directory that holds the database file.
        expiry_time: Only compared for zero versus non-zero.
        topic: Expected topic.
        payload: Expected payload.
        source_id: Expected source client id.
        source_username: Expected source username.
        payloadlen: Expected stored payload length; when non-zero the stored
            length must also equal the length of the stored payload.
        source_mid: Expected source message id.
        source_port: Expected source port.
        qos: Expected QoS.
        retain: Expected retain value.
        idx: Zero-based position of the row to check, in the order the query
            returns rows.

    Returns:
        The store_id of the checked row.

    Raises:
        ValueError: If there is no row at `idx`, the store_id is 0, or any
            checked column differs.
    """
    row = None
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT store_id,expiry_time,topic,payload,source_id,source_username, ' +
                    'payloadlen, source_mid, source_port, qos, retain ' +
                    'FROM base_msgs')
        # Step through the result set until the requested row is reached.
        for _ in range(0, idx+1):
            row = cur.fetchone()
    finally:
        # Release the database before validating so a failed check cannot
        # leave the connection open.
        con.close()

    if row is None:
        raise ValueError("No base message found at index %d" % (idx,))

    if row[0] == 0:
        # The caller passes no expected store_id; 0 is simply never valid.
        raise ValueError("Invalid store_id %d, expected non-zero" % (row[0],))

    if (expiry_time == 0 and row[1] != 0) or (expiry_time != 0 and row[1] == 0):
        raise ValueError("Invalid expiry_time %d / %d" % (row[1], expiry_time))

    if row[2] != topic:
        raise ValueError("Invalid topic %s / %s" % (row[2], topic))

    if row[3] != payload:
        raise ValueError("Invalid payload %s / %s" % (row[3], payload))

    if row[4] != source_id:
        raise ValueError("Invalid source_id %s / %s" % (row[4], source_id))

    if row[5] != source_username:
        raise ValueError("Invalid source_username %s / %s" % (row[5], source_username))

    if row[6] != payloadlen or (payloadlen != 0 and row[6] != len(row[3])):
        raise ValueError("Invalid payloadlen %d / %d" % (row[6], payloadlen))

    if row[7] != source_mid:
        raise ValueError("Invalid source_mid %d / %d" % (row[7], source_mid))

    if row[8] != source_port:
        raise ValueError("Invalid source_port %d / %d" % (row[8], source_port))

    if row[9] != qos:
        raise ValueError("Invalid qos %d / %d" % (row[9], qos))

    if row[10] != retain:
        raise ValueError("Invalid retain %d / %d" % (row[10], retain))

    return row[0]
def check_retain(port, topic, store_id):
    """Verify the `retains` entry for `topic` in `<port>/mosquitto.sqlite3`.

    Args:
        port: Value naming the directory that holds the database file.
        topic: Topic whose retained entry is looked up.
        store_id: Expected store_id of that entry.

    Raises:
        ValueError: If no entry exists for `topic` or its store_id differs.
    """
    con = sqlite3.connect(f"{port}/mosquitto.sqlite3")
    try:
        cur = con.cursor()
        cur.execute('SELECT store_id FROM retains WHERE topic=?', (topic,))
        row = cur.fetchone()
    finally:
        # Release the database before validating so a failed check cannot
        # leave the connection open.
        con.close()

    if row is None:
        raise ValueError("No retained message found for topic %s" % (topic,))

    if row[0] != store_id:
        raise ValueError("Invalid store_id %d / %d" % (row[0], store_id))