#include "doq-common.hh"
+#if 0
+#define DEBUGLOG_ENABLED
+#define DEBUGLOG(x) std::cerr << x << std::endl;
+#else
+#define DEBUGLOG(x)
+#endif
+
using namespace dnsdist::doq;
class H3Connection
Socket sock(clientState->udpFD);
- PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
DEBUGLOG("Successfully created HTTP/3 connection");
}
+ std::map<std::string, std::string> headers;
while (true) {
quiche_h3_event* ev;
// Processes HTTP/3 data received from the peer
switch (quiche_h3_event_type(ev)) {
case QUICHE_H3_EVENT_HEADERS: {
- std::string path;
int rc = quiche_h3_event_for_each_header(
ev,
[](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
std::string_view key(reinterpret_cast<char*>(name), name_len);
std::string_view content(reinterpret_cast<char*>(value), value_len);
- if (key == ":path") {
- auto pathptr = reinterpret_cast<std::string*>(argp);
- *pathptr = content;
- }
+ auto headersptr = reinterpret_cast<std::map<std::string, std::string>*>(argp);
+ headersptr->emplace(key, content);
return 0;
},
- &path);
- if (rc != 0) {
+ &headers);
+ if (rc != 0 || !headers.count(":method")) {
DEBUGLOG("Failed to process headers");
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState->nonCompliantQueries;
h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to process query headers");
break;
}
- if (path.empty()) {
- DEBUGLOG("Path not found");
- ++dnsdist::metrics::g_stats.nonCompliantQueries;
- ++clientState->nonCompliantQueries;
- ++frontend->d_errorResponses;
- h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found");
- break;
- }
- {
- auto pos = path.find("?dns=");
- if (pos == string::npos) {
- pos = path.find("&dns=");
+ if (headers.at(":method") == "GET") {
+ if (!headers.count(":path") || headers.at(":path").empty()) {
+ DEBUGLOG("Path not found");
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found");
+ break;
}
- if (pos != string::npos) {
- // need to base64url decode this
- string sdns(path.substr(pos + 5));
- boost::replace_all(sdns, "-", "+");
- boost::replace_all(sdns, "_", "/");
- // re-add padding that may have been missing
- switch (sdns.size() % 4) {
- case 2:
- sdns.append(2, '=');
- break;
- case 3:
- sdns.append(1, '=');
- break;
- }
-
- PacketBuffer decoded;
-
- /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */
- const size_t maxAdditionalSizeForEDNS = 35U;
- /* rough estimate so we hopefully don't need a new allocation later */
- /* We reserve at few additional bytes to be able to add EDNS later */
- const size_t estimate = ((sdns.size() * 3) / 4);
- decoded.reserve(estimate + maxAdditionalSizeForEDNS);
- if (B64Decode(sdns, decoded) < 0) {
- DEBUGLOG("Unable to base64 decode()");
- ++dnsdist::metrics::g_stats.nonCompliantQueries;
- ++clientState->nonCompliantQueries;
- ++frontend->d_errorResponses;
- h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL");
- break;
- }
-
- if (decoded.size() < sizeof(dnsheader)) {
- ++dnsdist::metrics::g_stats.nonCompliantQueries;
- ++clientState->nonCompliantQueries;
- ++frontend->d_errorResponses;
- h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
- break;
- }
- DEBUGLOG("Dispatching query");
- doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
- conn->get().d_streamBuffers.erase(streamID);
+ auto pos = headers.at(":path").find("?dns=");
+ if (pos == string::npos) {
+ pos = headers.at(":path").find("&dns=");
}
- else {
+ if (pos == string::npos) {
DEBUGLOG("User error, unable to find the DNS parameter");
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState->nonCompliantQueries;
h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to find the DNS parameter");
break;
}
+ // need to base64url decode this
+ string sdns(headers.at(":path").substr(pos + 5));
+ boost::replace_all(sdns, "-", "+");
+ boost::replace_all(sdns, "_", "/");
+ // re-add padding that may have been missing
+ switch (sdns.size() % 4) {
+ case 2:
+ sdns.append(2, '=');
+ break;
+ case 3:
+ sdns.append(1, '=');
+ break;
+ }
+
+ PacketBuffer decoded;
+ /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */
+ const size_t maxAdditionalSizeForEDNS = 35U;
+ /* rough estimate so we hopefully don't need a new allocation later */
+ /* We reserve at few additional bytes to be able to add EDNS later */
+ const size_t estimate = ((sdns.size() * 3) / 4);
+ decoded.reserve(estimate + maxAdditionalSizeForEDNS);
+ if (B64Decode(sdns, decoded) < 0) {
+ DEBUGLOG("Unable to base64 decode()");
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL");
+ break;
+ }
+ if (decoded.size() < sizeof(dnsheader)) {
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
+ break;
+ }
+ DEBUGLOG("Dispatching GET query");
+ doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
+ conn->get().d_streamBuffers.erase(streamID);
+ }
+ else if (headers.at(":method") == "POST") {
+ if (!quiche_h3_event_headers_has_body(ev)) {
+ DEBUGLOG("Empty POST query");
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Empty POST query");
+ break;
+ }
+ }
+ else {
+ DEBUGLOG("Unsupported HTTP method");
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported HTTP method");
+ break;
}
break;
}
+ case QUICHE_H3_EVENT_DATA: {
+ if (!headers.count("content-type") || headers.at("content-type") != "application/dns-message") {
+ DEBUGLOG("Unsupported content-type");
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported content-type");
+ break;
+ }
+ PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
+ PacketBuffer decoded;
+
+ while (true) {
+ ssize_t len = quiche_h3_recv_body(conn->get().d_http3.get(),
+ conn->get().d_conn.get(), streamID,
+ buffer.data(), buffer.capacity());
- case QUICHE_H3_EVENT_DATA:
+ if (len <= 0) {
+ break;
+ }
+ decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
+ }
+ if (decoded.size() < sizeof(dnsheader)) {
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState->nonCompliantQueries;
+ ++frontend->d_errorResponses;
+ h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query");
+ break;
+ }
+ DEBUGLOG("Dispatching POST query");
+ doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID);
+ conn->get().d_streamBuffers.erase(streamID);
+ }
case QUICHE_H3_EVENT_FINISHED:
case QUICHE_H3_EVENT_RESET:
case QUICHE_H3_EVENT_PRIORITY_UPDATE:
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset
-#from aioquic.quic.logger import QuicFileLogger
from aioquic.tls import CipherSuite, SessionTicket
from doqclient import StreamResetError
-#
-#class DnsClientProtocol(QuicConnectionProtocol):
-# def __init__(self, *args, **kwargs):
-# super().__init__(*args, **kwargs)
-# self._ack_waiter: Any = None
-#
-# def pack(self, data):
-# # serialize query
-# data = bytes(data)
-# data = struct.pack("!H", len(data)) + data
-# return data
-#
-# async def query(self, query: dns.message) -> None:
-# data = self.pack(query.to_wire())
-# # send query and wait for answer
-# stream_id = self._quic.get_next_available_stream_id()
-# self._quic.send_stream_data(stream_id, data, end_stream=True)
-# waiter = self._loop.create_future()
-# self._ack_waiter = waiter
-# self.transmit()
-#
-# return await asyncio.shield(waiter)
-#
-# def quic_event_received(self, event: QuicEvent) -> None:
-# if self._ack_waiter is not None:
-# if isinstance(event, StreamDataReceived):
-# length = struct.unpack("!H", bytes(event.data[:2]))[0]
-# answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True)
-#
-# waiter = self._ack_waiter
-# self._ack_waiter = None
-# waiter.set_result(answer)
-# if isinstance(event, StreamReset):
-# waiter = self._ack_waiter
-# self._ack_waiter = None
-# waiter.set_result(event)
-#
-#class BogusDnsClientProtocol(DnsClientProtocol):
-# def pack(self, data):
-# # serialize query
-# data = bytes(data)
-# data = struct.pack("!H", len(data) * 2) + data
-# return data
+
HttpConnection = Union[H0Connection, H3Connection]
class URL:
async def perform_http_request(
client: HttpClient,
url: str,
- data: Optional[str],
+ data: Optional[bytes],
include: bool,
output_dir: Optional[str],
) -> None:
# perform request
start = time.time()
if data is not None:
- data_bytes = data.encode()
http_events = await client.post(
url,
- data=data_bytes,
+ data=data,
headers={
- "content-length": str(len(data_bytes)),
- "content-type": "application/x-www-form-urlencoded",
+ "content-length": str(len(data)),
+ "content-type": "application/dns-message",
},
)
method = "POST"
port: int,
query: dns.message,
timeout: float,
- create_protocol=HttpClient
+ post: bool,
+ create_protocol=HttpClient,
) -> None:
- url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('='))
+ url = baseurl
+ if not post:
+ url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('='))
async with connect(
"127.0.0.1",
port,
answer = await perform_http_request(
client=client,
url=url,
- data=None,
+ data=query.to_wire() if post else None,
include=False,
output_dir=None,
)
except asyncio.TimeoutError as e:
return e
-def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None):
+
+def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
if verify:
configuration.load_verify_locations(verify)
+
result = asyncio.run(
async_h3_query(
configuration=configuration,
port=port,
query=query,
timeout=timeout,
- create_protocol=HttpClient
+ create_protocol=HttpClient,
+ post=post
)
)
def sendQUICQuery(self, query, response=None, useQueue=True, connection=None):
return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection)
+
+class TestDOH3Specifics(DNSDistTest):
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _doqServerPort = pickAvailablePort()
+ _dohBaseURL = ("https://%s:%d/" % (_serverName, _doqServerPort))
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'})
+ """
+ _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey']
+ _verboseMode = True
+
+ def testDOH3Post(self):
+ """
+ QUIC: Simple POST query
+ """
+ name = 'simple.post.doq.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ expectedQuery.id = 0
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, serverName=self._serverName, post=True)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.assertEqual(receivedResponse, response)