Refactor handle__connect() ahead of extended auth changes.

pull/1239/head
Roger A. Light 7 years ago
parent 636d0f1f74
commit fe854d3a64

@ -272,6 +272,7 @@ struct mosquitto {
# endif # endif
# endif # endif
bool ws_want_write; bool ws_want_write;
bool assigned_id;
#else #else
# ifdef WITH_SOCKS # ifdef WITH_SOCKS
char *socks5_host; char *socks5_host;

@ -113,6 +113,156 @@ void connection_check_acl(struct mosquitto_db *db, struct mosquitto *context, st
} }
int connect__on_authorised(struct mosquitto_db *db, struct mosquitto *context)
{
struct mosquitto *found_context;
struct mosquitto__subleaf *leaf;
mosquitto_property *connack_props = NULL;
uint8_t connect_ack = 0;
int i;
int rc;
/* Find if this client already has an entry. This must be done *after* any security checks. */
HASH_FIND(hh_id, db->contexts_by_id, context->id, strlen(context->id), found_context);
if(found_context){
/* Found a matching client */
if(found_context->sock == INVALID_SOCKET){
/* Client is reconnecting after a disconnect */
/* FIXME - does anything need to be done here? */
}else{
/* Client is already connected, disconnect old version. This is
* done in context__cleanup() below. */
if(db->config->connection_messages == true){
log__printf(NULL, MOSQ_LOG_ERR, "Client %s already connected, closing old connection.", context->id);
}
}
if(context->clean_start == false && found_context->session_expiry_interval > 0){
if(context->protocol == mosq_p_mqtt311 || context->protocol == mosq_p_mqtt5){
connect_ack |= 0x01;
}
if(found_context->inflight_msgs || found_context->queued_msgs){
context->inflight_msgs = found_context->inflight_msgs;
context->queued_msgs = found_context->queued_msgs;
found_context->inflight_msgs = NULL;
found_context->queued_msgs = NULL;
db__message_reconnect_reset(db, context);
}
context->subs = found_context->subs;
found_context->subs = NULL;
context->sub_count = found_context->sub_count;
found_context->sub_count = 0;
context->last_mid = found_context->last_mid;
for(i=0; i<context->sub_count; i++){
if(context->subs[i]){
leaf = context->subs[i]->subs;
while(leaf){
if(leaf->context == found_context){
leaf->context = context;
}
leaf = leaf->next;
}
}
}
}
session_expiry__remove(found_context);
found_context->clean_start = true;
found_context->session_expiry_interval = 0;
context__set_state(found_context, mosq_cs_duplicate);
do_disconnect(db, found_context);
}
rc = acl__find_acls(db, context);
if(rc) return rc;
if(db->config->connection_messages == true){
if(context->is_bridge){
if(context->username){
log__printf(NULL, MOSQ_LOG_NOTICE, "New bridge connected from %s as %s (p%d, c%d, k%d, u'%s').",
context->address, context->id, context->protocol, context->clean_start, context->keepalive, context->username);
}else{
log__printf(NULL, MOSQ_LOG_NOTICE, "New bridge connected from %s as %s (p%d, c%d, k%d).",
context->address, context->id, context->protocol, context->clean_start, context->keepalive);
}
}else{
if(context->username){
log__printf(NULL, MOSQ_LOG_NOTICE, "New client connected from %s as %s (p%d, c%d, k%d, u'%s').",
context->address, context->id, context->protocol, context->clean_start, context->keepalive, context->username);
}else{
log__printf(NULL, MOSQ_LOG_NOTICE, "New client connected from %s as %s (p%d, c%d, k%d).",
context->address, context->id, context->protocol, context->clean_start, context->keepalive);
}
}
if(context->will) {
log__printf(NULL, MOSQ_LOG_DEBUG, "Will message specified (%ld bytes) (r%d, q%d).",
(long)context->will->msg.payloadlen,
context->will->msg.retain,
context->will->msg.qos);
log__printf(NULL, MOSQ_LOG_DEBUG, "\t%s", context->will->msg.topic);
} else {
log__printf(NULL, MOSQ_LOG_DEBUG, "No will message specified.");
}
}
context->ping_t = 0;
context->is_dropping = false;
connection_check_acl(db, context, &context->inflight_msgs);
connection_check_acl(db, context, &context->queued_msgs);
HASH_ADD_KEYPTR(hh_id, db->contexts_by_id, context->id, strlen(context->id), context);
#ifdef WITH_PERSISTENCE
if(!context->clean_start){
db->persistence_changes++;
}
#endif
context->maximum_qos = context->listener->maximum_qos;
if(context->protocol == mosq_p_mqtt5){
if(context->maximum_qos != 2){
if(mosquitto_property_add_byte(&connack_props, MQTT_PROP_MAXIMUM_QOS, context->maximum_qos)){
rc = MOSQ_ERR_NOMEM;
goto error;
}
}
if(context->listener->max_topic_alias > 0){
if(mosquitto_property_add_int16(&connack_props, MQTT_PROP_TOPIC_ALIAS_MAXIMUM, context->listener->max_topic_alias)){
rc = MOSQ_ERR_NOMEM;
goto error;
}
}
if(context->keepalive > db->config->max_keepalive){
context->keepalive = db->config->max_keepalive;
if(mosquitto_property_add_int16(&connack_props, MQTT_PROP_SERVER_KEEP_ALIVE, context->keepalive)){
rc = MOSQ_ERR_NOMEM;
goto error;
}
}
if(context->assigned_id){
if(mosquitto_property_add_string(&connack_props, MQTT_PROP_ASSIGNED_CLIENT_IDENTIFIER, context->id)){
rc = MOSQ_ERR_NOMEM;
goto error;
}
}
}
context__set_state(context, mosq_cs_connected);
rc = send__connack(db, context, connect_ack, CONNACK_ACCEPTED, connack_props);
mosquitto_property_free_all(&connack_props);
return rc;
error:
mosquitto_property_free_all(&connack_props);
return rc;
}
static int will__read(struct mosquitto *context, struct mosquitto_message_all **will, uint8_t will_qos, int will_retain) static int will__read(struct mosquitto *context, struct mosquitto_message_all **will, uint8_t will_qos, int will_retain)
{ {
int rc = MOSQ_ERR_SUCCESS; int rc = MOSQ_ERR_SUCCESS;
@ -198,20 +348,16 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context)
char protocol_name[7]; char protocol_name[7];
uint8_t protocol_version; uint8_t protocol_version;
uint8_t connect_flags; uint8_t connect_flags;
uint8_t connect_ack = 0;
char *client_id = NULL; char *client_id = NULL;
struct mosquitto_message_all *will_struct = NULL; struct mosquitto_message_all *will_struct = NULL;
uint8_t will, will_retain, will_qos, clean_start; uint8_t will, will_retain, will_qos, clean_start;
uint8_t username_flag, password_flag; uint8_t username_flag, password_flag;
char *username = NULL, *password = NULL; char *username = NULL, *password = NULL;
int rc; int rc;
struct mosquitto *found_context;
int slen; int slen;
uint16_t slen16; uint16_t slen16;
struct mosquitto__subleaf *leaf;
int i; int i;
mosquitto_property *properties = NULL; mosquitto_property *properties = NULL;
mosquitto_property *connack_props = NULL;
#ifdef WITH_TLS #ifdef WITH_TLS
X509 *client_cert = NULL; X509 *client_cert = NULL;
X509_NAME *name; X509_NAME *name;
@ -293,20 +439,6 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context)
goto handle_connect_error; goto handle_connect_error;
} }
context->maximum_qos = context->listener->maximum_qos;
if(protocol_version == PROTOCOL_VERSION_v5 && context->maximum_qos != 2){
if(mosquitto_property_add_byte(&connack_props, MQTT_PROP_MAXIMUM_QOS, context->maximum_qos)){
rc = MOSQ_ERR_NOMEM;
goto handle_connect_error;
}
}
if(protocol_version == PROTOCOL_VERSION_v5 && context->listener->max_topic_alias > 0){
if(mosquitto_property_add_int16(&connack_props, MQTT_PROP_TOPIC_ALIAS_MAXIMUM, context->listener->max_topic_alias)){
rc = MOSQ_ERR_NOMEM;
goto handle_connect_error;
}
}
if(packet__read_byte(&context->in_packet, &connect_flags)){ if(packet__read_byte(&context->in_packet, &connect_flags)){
rc = 1; rc = 1;
goto handle_connect_error; goto handle_connect_error;
@ -350,13 +482,6 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context)
rc = 1; rc = 1;
goto handle_connect_error; goto handle_connect_error;
} }
if(protocol_version == PROTOCOL_VERSION_v5 && context->keepalive > db->config->max_keepalive){
context->keepalive = db->config->max_keepalive;
if(mosquitto_property_add_int16(&connack_props, MQTT_PROP_SERVER_KEEP_ALIVE, context->keepalive)){
rc = MOSQ_ERR_NOMEM;
goto handle_connect_error;
}
}
if(protocol_version == PROTOCOL_VERSION_v5){ if(protocol_version == PROTOCOL_VERSION_v5){
rc = property__read_all(CMD_CONNECT, &context->in_packet, &properties); rc = property__read_all(CMD_CONNECT, &context->in_packet, &properties);
@ -408,12 +533,7 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context)
rc = MOSQ_ERR_NOMEM; rc = MOSQ_ERR_NOMEM;
goto handle_connect_error; goto handle_connect_error;
} }
if(context->protocol == mosq_p_mqtt5){ context->assigned_id = true;
if(mosquitto_property_add_string(&connack_props, MQTT_PROP_ASSIGNED_CLIENT_IDENTIFIER, client_id)){
rc = MOSQ_ERR_NOMEM;
goto handle_connect_error;
}
}
} }
} }
} }
@ -634,121 +754,11 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context)
goto handle_connect_error; goto handle_connect_error;
} }
} }
/* Find if this client already has an entry. This must be done *after* any security checks. */
HASH_FIND(hh_id, db->contexts_by_id, client_id, strlen(client_id), found_context);
if(found_context){
/* Found a matching client */
if(found_context->sock == INVALID_SOCKET){
/* Client is reconnecting after a disconnect */
/* FIXME - does anything need to be done here? */
}else{
/* Client is already connected, disconnect old version. This is
* done in context__cleanup() below. */
if(db->config->connection_messages == true){
log__printf(NULL, MOSQ_LOG_ERR, "Client %s already connected, closing old connection.", client_id);
}
}
context->clean_start = clean_start; context->clean_start = clean_start;
if(context->clean_start == false && found_context->session_expiry_interval > 0){
if(context->protocol == mosq_p_mqtt311 || context->protocol == mosq_p_mqtt5){
connect_ack |= 0x01;
}
if(found_context->inflight_msgs || found_context->queued_msgs){
context->inflight_msgs = found_context->inflight_msgs;
context->queued_msgs = found_context->queued_msgs;
found_context->inflight_msgs = NULL;
found_context->queued_msgs = NULL;
db__message_reconnect_reset(db, context);
}
context->subs = found_context->subs;
found_context->subs = NULL;
context->sub_count = found_context->sub_count;
found_context->sub_count = 0;
context->last_mid = found_context->last_mid;
for(i=0; i<context->sub_count; i++){
if(context->subs[i]){
leaf = context->subs[i]->subs;
while(leaf){
if(leaf->context == found_context){
leaf->context = context;
}
leaf = leaf->next;
}
}
}
}
session_expiry__remove(found_context);
found_context->clean_start = true;
found_context->session_expiry_interval = 0;
context__set_state(found_context, mosq_cs_duplicate);
do_disconnect(db, found_context);
}
rc = acl__find_acls(db, context);
if(rc) return rc;
if(will_struct){
context->will = will_struct;
will_struct = NULL;
}
if(db->config->connection_messages == true){
if(context->is_bridge){
if(context->username){
log__printf(NULL, MOSQ_LOG_NOTICE, "New bridge connected from %s as %s (c%d, k%d, u'%s').", context->address, client_id, clean_start, context->keepalive, context->username);
}else{
log__printf(NULL, MOSQ_LOG_NOTICE, "New bridge connected from %s as %s (c%d, k%d).", context->address, client_id, clean_start, context->keepalive);
}
}else{
if(context->username){
log__printf(NULL, MOSQ_LOG_NOTICE, "New client connected from %s as %s (c%d, k%d, u'%s').", context->address, client_id, clean_start, context->keepalive, context->username);
}else{
log__printf(NULL, MOSQ_LOG_NOTICE, "New client connected from %s as %s (p%d, c%d, k%d).", context->address, client_id, context->protocol, clean_start, context->keepalive);
}
}
if(context->will) {
log__printf(NULL, MOSQ_LOG_DEBUG, "Will message specified (%ld bytes) (r%d, q%d).",
(long)context->will->msg.payloadlen,
context->will->msg.retain,
context->will->msg.qos);
log__printf(NULL, MOSQ_LOG_DEBUG, "\t%s", context->will->msg.topic);
} else {
log__printf(NULL, MOSQ_LOG_DEBUG, "No will message specified.");
}
}
context->id = client_id; context->id = client_id;
client_id = NULL; context->will = will_struct;
context->clean_start = clean_start;
context->ping_t = 0;
context->is_dropping = false;
if((protocol_version&0x80) == 0x80){
context->is_bridge = true;
}
connection_check_acl(db, context, &context->inflight_msgs);
connection_check_acl(db, context, &context->queued_msgs);
HASH_ADD_KEYPTR(hh_id, db->contexts_by_id, context->id, strlen(context->id), context);
#ifdef WITH_PERSISTENCE return connect__on_authorised(db, context);
if(!clean_start){
db->persistence_changes++;
}
#endif
context__set_state(context, mosq_cs_connected);
rc = send__connack(db, context, connect_ack, CONNACK_ACCEPTED, connack_props);
mosquitto_property_free_all(&connack_props);
return rc;
handle_connect_error: handle_connect_error:
mosquitto__free(client_id); mosquitto__free(client_id);
@ -760,7 +770,6 @@ handle_connect_error:
mosquitto__free(will_struct->msg.topic); mosquitto__free(will_struct->msg.topic);
mosquitto__free(will_struct); mosquitto__free(will_struct);
} }
mosquitto_property_free_all(&connack_props);
#ifdef WITH_TLS #ifdef WITH_TLS
if(client_cert) X509_free(client_cert); if(client_cert) X509_free(client_cert);
#endif #endif

@ -3,6 +3,7 @@
# Test whether a client can connect without an SSL certificate if one is required. # Test whether a client can connect without an SSL certificate if one is required.
from mosq_test_helper import * from mosq_test_helper import *
import errno
if sys.version < '2.7': if sys.version < '2.7':
print("WARNING: SSL not supported on Python 2.6") print("WARNING: SSL not supported on Python 2.6")

@ -35,7 +35,7 @@ def do_test(port, connack_rc, username, password):
sock.close() sock.close()
finally: finally:
if rc: if rc:
exit(rc) raise AssertionError
def username_password_tests(port): def username_password_tests(port):
@ -161,3 +161,5 @@ try:
finally: finally:
os.remove(conf_file) os.remove(conf_file)
os.remove(pw_file) os.remove(pw_file)
sys.exit(0)

Loading…
Cancel
Save