diff --git a/src/handle_connect.c b/src/handle_connect.c index 09e29e37..d47e0011 100644 --- a/src/handle_connect.c +++ b/src/handle_connect.c @@ -157,6 +157,84 @@ void connection_check_acl(struct mosquitto_db *db, struct mosquitto *context, st } } + +static int will__read(struct mosquitto *context, struct mosquitto_message_all **will, uint8_t will_qos, int will_retain) +{ + int rc = MOSQ_ERR_SUCCESS; + int slen; + struct mosquitto_message_all *will_struct = NULL; + char *will_topic_mount = NULL; + uint16_t payloadlen; + + will_struct = mosquitto__calloc(1, sizeof(struct mosquitto_message_all)); + if(!will_struct){ + rc = MOSQ_ERR_NOMEM; + goto error_cleanup; + } + if(context->protocol == PROTOCOL_VERSION_v5){ + rc = property__read_all(CMD_WILL, &context->in_packet, &will_struct->properties); + if(rc) goto error_cleanup; + + mosquitto_property_read_int32(will_struct->properties, MQTT_PROP_MESSAGE_EXPIRY_INTERVAL, &will_struct->expiry_interval, false); + mosquitto_property_free_all(&will_struct->properties); /* FIXME - TEMPORARY UNTIL PROPERTIES PROCESSED */ + } + rc = packet__read_string(&context->in_packet, &will_struct->msg.topic, &slen); + if(rc) goto error_cleanup; + if(!slen){ + rc = MOSQ_ERR_PROTOCOL; + goto error_cleanup; + } + + if(context->listener->mount_point){ + slen = strlen(context->listener->mount_point) + strlen(will_struct->msg.topic) + 1; + will_topic_mount = mosquitto__malloc(slen+1); + if(!will_topic_mount){ + rc = MOSQ_ERR_NOMEM; + goto error_cleanup; + } + + snprintf(will_topic_mount, slen, "%s%s", context->listener->mount_point, will_struct->msg.topic); + will_topic_mount[slen] = '\0'; + + mosquitto__free(will_struct->msg.topic); + will_struct->msg.topic = will_topic_mount; + } + + rc = mosquitto_pub_topic_check(will_struct->msg.topic); + if(rc) goto error_cleanup; + + rc = packet__read_uint16(&context->in_packet, &payloadlen); + if(rc) goto error_cleanup; + + will_struct->msg.payloadlen = payloadlen; + if(will_struct->msg.payloadlen > 0){ + will_struct->msg.payload = mosquitto__malloc(will_struct->msg.payloadlen); + if(!will_struct->msg.payload){ + rc = MOSQ_ERR_NOMEM; + goto error_cleanup; + } + + rc = packet__read_bytes(&context->in_packet, will_struct->msg.payload, will_struct->msg.payloadlen); + if(rc) goto error_cleanup; + } + + will_struct->msg.qos = will_qos; + will_struct->msg.retain = will_retain; + + *will = will_struct; + return MOSQ_ERR_SUCCESS; + +error_cleanup: + if(will_struct){ + mosquitto__free(will_struct->msg.topic); + mosquitto__free(will_struct->msg.payload); + mosquitto__free(will_struct); + } + return rc; +} + + + int handle__connect(struct mosquitto_db *db, struct mosquitto *context) { char protocol_name[7]; @@ -164,10 +242,6 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context) uint8_t connect_flags; uint8_t connect_ack = 0; char *client_id = NULL; - char *will_payload = NULL, *will_topic = NULL; - char *will_topic_mount; - uint16_t will_payloadlen; - uint32_t will_expiry_interval = 0; struct mosquitto_message_all *will_struct = NULL; uint8_t will, will_retain, will_qos, clean_start; uint8_t username_flag, password_flag; @@ -375,62 +449,8 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context) } if(will){ - will_struct = mosquitto__calloc(1, sizeof(struct mosquitto_message_all)); - if(!will_struct){ - rc = MOSQ_ERR_NOMEM; - goto handle_connect_error; - } - if(protocol_version == PROTOCOL_VERSION_v5){ - rc = property__read_all(CMD_WILL, &context->in_packet, &will_struct->properties); - if(rc) return rc; - mosquitto_property_read_int32(properties, MQTT_PROP_MESSAGE_EXPIRY_INTERVAL, &will_expiry_interval, false); - mosquitto_property_free_all(&properties); /* FIXME - TEMPORARY UNTIL PROPERTIES PROCESSED */ - } - if(packet__read_string(&context->in_packet, &will_topic, &slen)){ - rc = 1; - goto handle_connect_error; - } - if(!slen){ - rc = 1; - goto handle_connect_error; - } - - if(context->listener->mount_point){ - slen = strlen(context->listener->mount_point) + strlen(will_topic) + 1; - will_topic_mount = mosquitto__malloc(slen+1); - if(!will_topic_mount){ - rc = MOSQ_ERR_NOMEM; - goto handle_connect_error; - } - snprintf(will_topic_mount, slen, "%s%s", context->listener->mount_point, will_topic); - will_topic_mount[slen] = '\0'; - - mosquitto__free(will_topic); - will_topic = will_topic_mount; - } - - if(mosquitto_pub_topic_check(will_topic)){ - rc = 1; - goto handle_connect_error; - } - - if(packet__read_uint16(&context->in_packet, &will_payloadlen)){ - rc = 1; - goto handle_connect_error; - } - if(will_payloadlen > 0){ - will_payload = mosquitto__malloc(will_payloadlen); - if(!will_payload){ - rc = 1; - goto handle_connect_error; - } - - rc = packet__read_bytes(&context->in_packet, will_payload, will_payloadlen); - if(rc){ - rc = 1; - goto handle_connect_error; - } - } + rc = will__read(context, &will_struct, will_qos, will_retain); + if(rc) goto handle_connect_error; }else{ if(context->protocol == mosq_p_mqtt311 || context->protocol == mosq_p_mqtt5){ if(will_qos != 0 || will_retain != 0){ @@ -697,17 +717,7 @@ int handle__connect(struct mosquitto_db *db, struct mosquitto *context) if(will_struct){ context->will = will_struct; - context->will->msg.topic = will_topic; - if(will_payload){ - context->will->msg.payload = will_payload; - context->will->msg.payloadlen = will_payloadlen; - }else{ - context->will->msg.payload = NULL; - context->will->msg.payloadlen = 0; - } - context->will->msg.qos = will_qos; - context->will->msg.retain = will_retain; - context->will->expiry_interval = will_expiry_interval; + will_struct = NULL; } if(db->config->connection_messages == true){ @@ -765,13 +775,13 @@ handle_connect_error: mosquitto__free(client_id); mosquitto__free(username); mosquitto__free(password); - mosquitto__free(will_payload); - mosquitto__free(will_topic); if(will_struct){ mosquitto_property_free_all(&will_struct->properties); + mosquitto__free(will_struct->msg.payload); + mosquitto__free(will_struct->msg.topic); + mosquitto__free(will_struct); } mosquitto_property_free_all(&connack_props); - mosquitto__free(will_struct); #ifdef WITH_TLS if(client_cert) X509_free(client_cert); #endif