diff --git a/lib/mosquitto.c b/lib/mosquitto.c index 3dfe4f1d..c449cc98 100644 --- a/lib/mosquitto.c +++ b/lib/mosquitto.c @@ -200,6 +200,7 @@ int mosquitto_reinitialise(struct mosquitto *mosq, const char *id, bool clean_se mosq->ssl = NULL; mosq->tls_cert_reqs = SSL_VERIFY_PEER; mosq->tls_insecure = false; + mosq->want_write = false; #endif #ifdef WITH_THREADING pthread_mutex_init(&mosq->callback_mutex, NULL); @@ -839,12 +840,21 @@ int mosquitto_loop(struct mosquitto *mosq, int timeout, int max_packets) pthread_mutex_lock(&mosq->out_packet_mutex); if(mosq->out_packet || mosq->current_out_packet){ FD_SET(mosq->sock, &writefds); + } #ifdef WITH_TLS - }else if(mosq->ssl && mosq->want_write){ - FD_SET(mosq->sock, &writefds); - mosq->want_write = false; -#endif + if(mosq->ssl){ + if(mosq->want_write){ + FD_SET(mosq->sock, &writefds); + mosq->want_write = false; + }else if(mosq->want_connect){ + /* Remove possible FD_SET from above, we don't want to check + * for writing if we are still connecting, unless want_write is + * definitely set. The presence of outgoing packets does not + * matter yet. */ + FD_CLR(mosq->sock, &writefds); + } } +#endif pthread_mutex_unlock(&mosq->out_packet_mutex); pthread_mutex_unlock(&mosq->current_out_packet_mutex); }else{ @@ -908,9 +918,17 @@ int mosquitto_loop(struct mosquitto *mosq, int timeout, int max_packets) }else{ if(mosq->sock != INVALID_SOCKET){ if(FD_ISSET(mosq->sock, &readfds)){ - rc = mosquitto_loop_read(mosq, max_packets); - if(rc || mosq->sock == INVALID_SOCKET){ - return rc; +#ifdef WITH_TLS + if(mosq->want_connect){ + rc = mosquitto__socket_connect_tls(mosq); + if(rc) return rc; + }else +#endif + { + rc = mosquitto_loop_read(mosq, max_packets); + if(rc || mosq->sock == INVALID_SOCKET){ + return rc; + } } } if(mosq->sockpairR >= 0 && FD_ISSET(mosq->sockpairR, &readfds)){ @@ -926,9 +944,17 @@ int mosquitto_loop(struct mosquitto *mosq, int timeout, int max_packets) FD_SET(mosq->sock, &writefds); } if(FD_ISSET(mosq->sock, &writefds)){ - rc = mosquitto_loop_write(mosq, max_packets); - if(rc || mosq->sock == INVALID_SOCKET){ - return rc; +#ifdef WITH_TLS + if(mosq->want_connect){ + rc = mosquitto__socket_connect_tls(mosq); + if(rc) return rc; + }else +#endif + { + rc = mosquitto_loop_write(mosq, max_packets); + if(rc || mosq->sock == INVALID_SOCKET){ + return rc; + } } } } diff --git a/lib/mosquitto_internal.h b/lib/mosquitto_internal.h index f4fcde8a..dff9c333 100644 --- a/lib/mosquitto_internal.h +++ b/lib/mosquitto_internal.h @@ -180,6 +180,7 @@ struct mosquitto { bool tls_insecure; #endif bool want_write; + bool want_connect; #if defined(WITH_THREADING) && !defined(WITH_BROKER) pthread_mutex_t callback_mutex; pthread_mutex_t log_callback_mutex; diff --git a/lib/net_mosq.c b/lib/net_mosq.c index 67c22169..5246db2a 100644 --- a/lib/net_mosq.c +++ b/lib/net_mosq.c @@ -372,6 +372,32 @@ int _mosquitto_try_connect(struct mosquitto *mosq, const char *host, uint16_t po return rc; } +#ifdef WITH_TLS +int mosquitto__socket_connect_tls(struct mosquitto *mosq) +{ + int ret; + + ret = SSL_connect(mosq->ssl); + if(ret != 1){ + ret = SSL_get_error(mosq->ssl, ret); + if(ret == SSL_ERROR_WANT_READ){ + mosq->want_connect = true; + /* We always try to read anyway */ + }else if(ret == SSL_ERROR_WANT_WRITE){ + mosq->want_write = true; + mosq->want_connect = true; + }else{ + COMPAT_CLOSE(mosq->sock); + mosq->sock = INVALID_SOCKET; + return MOSQ_ERR_TLS; + } + }else{ + mosq->want_connect = false; + } + return MOSQ_ERR_SUCCESS; +} +#endif + /* Create a socket and connect it to 'ip' on port 'port'. * Returns -1 on failure (ip is NULL, socket creation/connection error) * Returns sock number on success. @@ -519,18 +545,11 @@ int _mosquitto_socket_connect(struct mosquitto *mosq, const char *host, uint16_t } SSL_set_bio(mosq->ssl, bio, bio); - ret = SSL_connect(mosq->ssl); - if(ret != 1){ - ret = SSL_get_error(mosq->ssl, ret); - if(ret == SSL_ERROR_WANT_READ){ - /* We always try to read anyway */ - }else if(ret == SSL_ERROR_WANT_WRITE){ - mosq->want_write = true; - }else{ - COMPAT_CLOSE(sock); - return MOSQ_ERR_TLS; - } + mosq->sock = sock; + if(mosquitto__socket_connect_tls(mosq)){ + return MOSQ_ERR_TLS; } + } #endif diff --git a/lib/net_mosq.h b/lib/net_mosq.h index af207dfe..e1162d0c 100644 --- a/lib/net_mosq.h +++ b/lib/net_mosq.h @@ -86,6 +86,7 @@ int _mosquitto_packet_read(struct mosquitto *mosq); #ifdef WITH_TLS int _mosquitto_socket_apply_tls(struct mosquitto *mosq); +int mosquitto__socket_connect_tls(struct mosquitto *mosq); #endif #endif