IOState tryHandshake() override
{
+ if (!d_feContext) {
+ /* we are a client, nothing to do */
+ return IOState::Done;
+ }
+
int res = SSL_accept(d_conn.get());
if (res == 1) {
return IOState::Done;
void doHandshake() override
{
+ if (!d_feContext) {
+ /* we are a client, nothing to do */
+ return;
+ }
+
int res = 0;
do {
res = SSL_accept(d_conn.get());
}
/* client-side connection */
- GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host)
+ GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
{
unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
#ifdef GNUTLS_NO_SIGNAL
do {
ret = gnutls_handshake(d_conn.get());
if (ret == GNUTLS_E_SUCCESS) {
+ d_handshakeDone = true;
return IOState::Done;
}
else if (ret == GNUTLS_E_AGAIN) {
throw std::runtime_error("Error accepting a new connection");
}
}
- while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+ while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
+
+ d_handshakeDone = true;
}
IOState tryHandshake() override
do {
ret = gnutls_handshake(d_conn.get());
if (ret == GNUTLS_E_SUCCESS) {
+ d_handshakeDone = true;
return IOState::Done;
}
else if (ret == GNUTLS_E_AGAIN) {
- return IOState::NeedRead;
+ int direction = gnutls_record_get_direction(d_conn.get());
+ return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
}
else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
{
+ if (!d_handshakeDone) {
+ auto state = tryHandshake();
+ if (state != IOState::Done) {
+ return state;
+ }
+ }
+
do {
ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
if (res == 0) {
IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead) override
{
+ if (!d_handshakeDone) {
+ auto state = tryHandshake();
+ if (state != IOState::Done) {
+ return state;
+ }
+ }
+
do {
ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
if (res == 0) {
std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
std::string d_host;
+ bool d_client{false};
+ bool d_handshakeDone{false};
};
class GnuTLSIOCtx: public TLSCtx