ComboAddress d_peer;
QuicheConnection d_conn;
+ std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
};
static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description);
std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();
+/* from rfc9250 section-4.3 */
+enum class DOQ_Error_Codes : uint64_t {
+ DOQ_NO_ERROR = 0,
+ DOQ_INTERNAL_ERROR = 1,
+ DOQ_PROTOCOL_ERROR = 2,
+ DOQ_REQUEST_CANCELLED = 3,
+ DOQ_EXCESSIVE_LOAD = 4,
+ DOQ_UNSPECIFIED_ERROR = 5
+};
+
static void handleResponse(DOQFrontend& df, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
{
if (response.size() == 0) {
- quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, 0x5);
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
+ return;
}
- else {
- uint16_t responseSize = static_cast<uint16_t>(response.size());
- const uint8_t sizeBytes[] = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
- auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes, sizeof(sizeBytes), false);
- if (res == sizeof(sizeBytes)) {
- res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data(), response.size(), true);
+
+ uint16_t responseSize = static_cast<uint16_t>(response.size());
+ const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
+ size_t pos = 0;
+ do {
+ auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes.data() + pos, sizeBytes.size() - pos, false);
+ if (res < 0) {
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return;
}
+ pos += res;
}
+ while (pos < sizeBytes.size());
+
+ pos = 0;
+ do {
+ auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data() + pos, response.size() - pos, true);
+ if (res < 0) {
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return;
+ }
+ pos += res;
+ }
+ while (pos < response.size());
}
static void fillRandom(PacketBuffer& buffer, size_t size)
Socket sock(cs->udpFD);
- PacketBuffer buffer(std::numeric_limits<unsigned short>::max());
+ PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
uint64_t streamID = 0;
while (quiche_stream_iter_next(readable.get(), &streamID)) {
+ auto& streamBuffer = conn->get().d_streamBuffers[streamID];
+ auto existingLength = streamBuffer.size();
bool fin = false;
- buffer.resize(std::numeric_limits<unsigned short>::max());
+ streamBuffer.resize(existingLength + 512);
auto received = quiche_conn_stream_recv(conn->get().d_conn.get(), streamID,
- buffer.data(), buffer.size(),
+ &streamBuffer.at(existingLength), 512,
&fin);
- if (received < 2) {
- break;
- }
- buffer.resize(received);
-
+ streamBuffer.resize(existingLength + received);
if (fin) {
- // we skip message length, should we verify ?
- buffer.erase(buffer.begin(), buffer.begin() + 2);
- if (buffer.size() >= sizeof(dnsheader)) {
- doq_dispatch_query(*(frontend->d_server_config), std::move(buffer), cs->local, client, serverConnID, streamID);
+ if (streamBuffer.size() < (sizeof(dnsheader) + sizeof(uint16_t))) {
+ quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+ break;
+ }
+ uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
+ streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
+ if (payloadLength != streamBuffer.size()) {
+ quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+ break;
}
+ doq_dispatch_query(*(frontend->d_server_config), std::move(streamBuffer), cs->local, client, serverConnID, streamID);
+ conn->get().d_streamBuffers.erase(streamID);
}
}
}
else {
DEBUGLOG("Connection not established");
}
- // }
}
if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {