QuicheConnection d_conn;
QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
+ std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
};
static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
std::shared_ptr<DOH3TCPCrossQuerySender> DOH3CrossProtocolQuery::s_sender = std::make_shared<DOH3TCPCrossQuerySender>();
-static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
+static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response)
+{
+ size_t pos = 0;
+ while (pos < response.size()) {
+ // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
+ auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
+ streamID, &response.at(pos), response.size() - pos, true);
+ if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
+ response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos));
+ return false;
+ }
+ if (res < 0) {
+ // Shutdown with internal error code
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return true;
+ }
+ pos += res;
+ }
+
+ return true;
+}
+
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
{
std::string status = std::to_string(statusCode);
std::string lenStr = std::to_string(len);
.value_len = lenStr.size(),
},
};
- quiche_h3_send_response(conn, quic_conn,
- streamID, headers.data(), headers.size(), len == 0);
+ auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
+ streamID, headers.data(), headers.size(), len == 0);
+ if (returnValue != 0) {
+ /* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return;
+ }
if (len == 0) {
return;
size_t pos = 0;
while (pos < len) {
// send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
- auto res = quiche_h3_send_body(conn, quic_conn,
+ auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
+ if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
+ conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len);
+ return;
+ }
if (res < 0) {
// Shutdown with internal error code
- quiche_conn_stream_shutdown(quic_conn, streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
return;
}
pos += res;
}
}
-static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
-}
-
-static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
-{
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len);
+ h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}
static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
}
}
+static void flushStalledResponses(H3Connection& conn)
+{
+ for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
+ const auto streamID = streamIt->first;
+ auto& response = streamIt->second;
+ if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) {
+ if (tryWriteResponse(conn, streamID, response)) {
+ streamIt = conn.d_streamOutBuffers.erase(streamIt);
+ continue;
+ }
+ }
+ ++streamIt;
+ }
+}
+
static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
{
auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState.nonCompliantQueries;
++frontend.d_errorResponses;
- h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+ h3_send_response(conn, streamID, 400, msg);
};
// Callback result. Any value other than 0 will interrupt further header processing.
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState.nonCompliantQueries;
++frontend.d_errorResponses;
- h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+ h3_send_response(conn, streamID, 400, msg);
};
- if (headers.at(":method") == "POST") {
- if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
- handleImmediateError("Unsupported content-type");
- return;
- }
- PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
- PacketBuffer decoded;
+ if (headers.at(":method") != "POST") {
+ handleImmediateError("DATA frame for non-POST method");
+ return;
+ }
- while (true) {
- ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
- conn.d_conn.get(), streamID,
- buffer.data(), buffer.capacity());
+ if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
+ handleImmediateError("Unsupported content-type");
+ return;
+ }
- if (len <= 0) {
- break;
- }
- decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
- }
+ PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
+ auto& streamBuffer = conn.d_streamBuffers[streamID];
- if (decoded.size() < sizeof(dnsheader)) {
- handleImmediateError("DoH3 non-compliant query");
- return;
+ while (true) {
+ buffer.resize(std::numeric_limits<uint16_t>::max());
+ ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
+ conn.d_conn.get(), streamID,
+ buffer.data(), buffer.capacity());
+
+ if (len <= 0) {
+ break;
}
- DEBUGLOG("Dispatching POST query");
- doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID);
+ buffer.resize(static_cast<size_t>(len));
+ streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end());
+ }
+
+ if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) {
+ return;
+ }
+
+ if (streamBuffer.size() < sizeof(dnsheader)) {
conn.d_streamBuffers.erase(streamID);
+ handleImmediateError("DoH3 non-compliant query");
+ return;
}
+
+ DEBUGLOG("Dispatching POST query");
+ doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+ conn.d_streamBuffers.erase(streamID);
}
static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID)
conn = frontend->d_server_config->d_connections.erase(conn);
}
else {
+ flushStalledResponses(conn->second);
++conn;
}
}
ComboAddress d_peer;
QuicheConnection d_conn;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
+ std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
};
static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();
-static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
+static bool tryWriteResponse(Connection& conn, const uint64_t streamID, PacketBuffer& response)
+{
+ size_t pos = 0;
+ while (pos < response.size()) {
+ auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true);
+ if (res == QUICHE_ERR_DONE) {
+ response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos));
+ return false;
+ }
+ if (res < 0) {
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return true;
+ }
+ pos += res;
+ }
+
+ return true;
+}
+
+static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, PacketBuffer& response)
{
if (response.empty()) {
++frontend.d_errorResponses;
++frontend.d_validResponses;
auto responseSize = static_cast<uint16_t>(response.size());
const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
- size_t pos = 0;
- while (pos < sizeBytes.size()) {
- auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &sizeBytes.at(pos), sizeBytes.size() - pos, false);
- if (res < 0) {
- quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
- return;
- }
- pos += res;
- }
-
- pos = 0;
- while (pos < response.size()) {
- // stream_send sets fin to false itself when the capacity of the stream is less than the desired writing length
- auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true);
- if (res < 0) {
- quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
- return;
- }
- pos += res;
+ response.insert(response.begin(), sizeBytes.begin(), sizeBytes.end());
+ if (!tryWriteResponse(conn, streamID, response)) {
+ conn.d_streamOutBuffers[streamID] = std::move(response);
}
}
}
}
+static void flushStalledResponses(Connection& conn)
+{
+ for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
+ const auto& streamID = streamIt->first;
+ auto& response = streamIt->second;
+ if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) {
+ if (tryWriteResponse(conn, streamID, response)) {
+ streamIt = conn.d_streamOutBuffers.erase(streamIt);
+ continue;
+ }
+ }
+ ++streamIt;
+ }
+}
+
// this is the entrypoint from dnsdist.cc
void doqThread(ClientState* clientState)
{
conn = frontend->d_server_config->d_connections.erase(conn);
}
else {
+ flushStalledResponses(conn->second);
++conn;
}
}