std::shared_ptr<DNSAction> d_action;
};
-#ifdef HAVE_DNS_OVER_HTTPS
+#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
class HTTPStatusAction : public DNSAction
{
public:
DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
{
- if (!dnsquestion->ids.du) {
- return Action::None;
+#if defined(HAVE_DNS_OVER_HTTPS)
+ if (dnsquestion->ids.du) {
+ dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ return Action::HeaderModify;
}
-
- dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
- header.qr = true; // for good measure
- setResponseHeadersFromConfig(header, d_responseConfig);
- return true;
- });
- return Action::HeaderModify;
+#endif /* HAVE_DNS_OVER_HTTPS */
+#if defined(HAVE_DNS_OVER_HTTP3)
+ if (dnsquestion->ids.doh3u) {
+ dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
+ return Action::HeaderModify;
+ }
+#endif /* HAVE_DNS_OVER_HTTP3 */
+ return Action::None;
}
[[nodiscard]] std::string toString() const override
std::string d_contentType;
int d_code;
};
-#endif /* HAVE_DNS_OVER_HTTPS */
+#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */
#if defined(HAVE_LMDB) || defined(HAVE_CDB)
class KeyValueStoreLookupAction : public DNSAction
#endif /* HAVE_NET_SNMP */
});
-#ifdef HAVE_DNS_OVER_HTTPS
+#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPPath", [](const DNSQuestion& dnsQuestion) {
if (dnsQuestion.ids.du) {
return dnsQuestion.ids.du->getHTTPPath();
});
luaCtx.registerFunction<void (DNSQuestion::*)(uint64_t statusCode, const std::string& body, const boost::optional<std::string> contentType)>("setHTTPResponse", [](DNSQuestion& dnsQuestion, uint64_t statusCode, const std::string& body, const boost::optional<std::string>& contentType) {
- if (dnsQuestion.ids.du == nullptr) {
+ if (dnsQuestion.ids.du == nullptr && dnsQuestion.ids.doh3u == nullptr) {
return;
}
checkParameterBound("DNSQuestion::setHTTPResponse", statusCode, std::numeric_limits<uint16_t>::max());
PacketBuffer vect(body.begin(), body.end());
- dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
+ if (dnsQuestion.ids.du) {
+ dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
+ }
+ else {
+ dnsQuestion.ids.doh3u->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
+ }
});
-#endif /* HAVE_DNS_OVER_HTTPS */
+#endif /* HAVE_DNS_OVER_HTTPS HAVE_DNS_OVER_HTTP3 */
luaCtx.registerFunction<bool (DNSQuestion::*)(bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum)>("setNegativeAndAdditionalSOA", [](DNSQuestion& dnsQuestion, bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum) {
checkParameterBound("setNegativeAndAdditionalSOA", ttl, std::numeric_limits<uint32_t>::max());
void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType)
{
- if (dq->dq->ids.du == nullptr) {
- return;
+#if defined(HAVE_DNS_OVER_HTTPS)
+ if (dq->dq->ids.du) {
+ PacketBuffer bodyVect(body, body + bodyLen);
+ dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
+ header.qr = true;
+ return true;
+ });
+ }
+#endif
+#if defined(HAVE_DNS_OVER_HTTP3)
+ if (dq->dq->ids.doh3u) {
+ PacketBuffer bodyVect(body, body + bodyLen);
+ dq->dq->ids.doh3u->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
+ header.qr = true;
+ return true;
+ });
}
-
-#ifdef HAVE_DNS_OVER_HTTPS
- PacketBuffer bodyVect(body, body + bodyLen);
- dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
- dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
- header.qr = true;
- return true;
- });
#endif
}
return true;
}
-static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
+static void addHeaderToList(std::vector<quiche_h3_header>& headers, const char* name, size_t nameLen, const char* value, size_t valueLen)
{
- std::string status = std::to_string(statusCode);
- std::string lenStr = std::to_string(len);
- std::array<quiche_h3_header, 3> headers{
- (quiche_h3_header){
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .name = reinterpret_cast<const uint8_t*>(":status"),
- .name_len = sizeof(":status") - 1,
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .value = reinterpret_cast<const uint8_t*>(status.data()),
- .value_len = status.size(),
- },
- (quiche_h3_header){
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .name = reinterpret_cast<const uint8_t*>("content-length"),
- .name_len = sizeof("content-length") - 1,
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .value = reinterpret_cast<const uint8_t*>(lenStr.data()),
- .value_len = lenStr.size(),
- },
- (quiche_h3_header){
+ headers.emplace_back((quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .name = reinterpret_cast<const uint8_t*>("content-type"),
- .name_len = sizeof("content-type") - 1,
+ .name = reinterpret_cast<const uint8_t*>(name),
+ .name_len = nameLen,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
- .value = reinterpret_cast<const uint8_t*>("application/dns-message"),
- .value_len = sizeof("application/dns-message") - 1,
- },
- };
+ .value = reinterpret_cast<const uint8_t*>(value),
+ .value_len = valueLen,
+ });
+}
+
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len, const std::string& contentType = {})
+{
+ std::string status = std::to_string(statusCode);
+ PacketBuffer location;
+ PacketBuffer responseBody;
+ std::vector<quiche_h3_header> headers;
+ headers.reserve(3);
+ addHeaderToList(headers, ":status", sizeof(":status") - 1, status.data(), status.size());
+
+ if (statusCode >= 300 && statusCode < 400) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ addHeaderToList(headers, "location", sizeof("location") - 1, reinterpret_cast<const char*>(body), len);
+ static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
+ static const std::string s_redirectEnd{"\">here</A>"};
+ static const std::string s_redirectContentType("text/html; charset=utf-8");
+ addHeaderToList(headers, "content-type", sizeof("content-type") - 1, s_redirectContentType.data(), s_redirectContentType.size());
+ responseBody.reserve(s_redirectStart.size() + len + s_redirectEnd.size());
+ responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
+ responseBody.insert(responseBody.end(), body, body + len);
+ responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
+ body = responseBody.data();
+ len = responseBody.size();
+ }
+ else if (len > 0 && (statusCode == 200U || !contentType.empty())) {
+ // do not include content-type header info if there is no content
+ addHeaderToList(headers, "content-type", sizeof("content-type") - 1, contentType.empty() ? "application/dns-message" : contentType.data(), contentType.empty() ? sizeof("application/dns-message") - 1 : contentType.size());
+ }
+
+ const std::string lenStr = std::to_string(len);
+ addHeaderToList(headers, "content-length", sizeof("content-length") - 1, lenStr.data(), lenStr.size());
+
auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
streamID, headers.data(),
- // do not include content-type header info if there is no content
- (len > 0 && statusCode == 200U ? headers.size() : headers.size() - 1),
+ headers.size(),
len == 0);
if (returnValue != 0) {
/* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
}
}
-static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content = {})
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}
-static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
+static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response, const std::string& contentType)
{
if (statusCode == 200) {
++frontend.d_validResponses;
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
}
else {
- h3_send_response(conn, streamID, statusCode, &response.at(0), response.size());
+ h3_send_response(conn, streamID, statusCode, &response.at(0), response.size(), contentType);
}
}
const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
DEBUGLOG("handleImmediateResponse() reason=" << reason);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
- handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
+ handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
unit->ids.doh3u.reset();
};
auto unit = std::move(*tmp);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
if (conn) {
- handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
+ handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
}
}
catch (const std::exception& e) {
return headers;
}
+void DOH3Unit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType)
+{
+ status_code = statusCode;
+ response = std::move(body);
+ d_contentTypeOut = contentType;
+}
+
#else /* HAVE_DNS_OVER_HTTP3 */
std::string DOH3Unit::getHTTPPath() const
return headers;
}
+void DOH3Unit::setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&)
+{
+}
+
#endif /* HAVE_DNS_OVER_HTTP3 */
#include <unordered_map>
#include "config.h"
+#include "noinitvector.hh"
#ifdef HAVE_DNS_OVER_HTTP3
#include "channel.hh"
#include "iputils.hh"
#include "libssl.hh"
-#include "noinitvector.hh"
#include "stat_t.hh"
#include "dnsdist-idstate.hh"
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
+ void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "");
InternalQueryState ids;
PacketBuffer query;
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
+ std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
+ void setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&);
};
struct DOH3Frontend
return (receivedQuery, message)
@classmethod
- def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None):
+ def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None, rawResponse=False):
if response:
if toQueue:
else:
cls._toResponderQueue.put(response, True, timeout)
- message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders)
+ if rawResponse:
+ return doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
+
+ message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)
receivedQuery = None
elapsed = time.time() - start
result = bytes()
+ headers = {}
for http_event in http_events:
if isinstance(http_event, DataReceived):
result += http_event.data
if isinstance(http_event, StreamReset):
result = http_event
- return result
+ if isinstance(http_event, HeadersReceived):
+ for k, v in http_event.headers:
+ headers[k] = v
+ return (result, headers)
async def async_h3_query(
return answer
except asyncio.TimeoutError as e:
- return e
+ return (e,{})
-def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None):
+def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
if verify:
configuration.load_verify_locations(verify)
- result = asyncio.run(
+ (result, headers) = asyncio.run(
async_h3_query(
configuration=configuration,
baseurl=baseurl,
raise StreamResetError(result.error_code)
if (isinstance(result, asyncio.TimeoutError)):
raise TimeoutError()
+ if raw_response:
+ return (result, headers)
return dns.message.from_wire(result)
addAction(HTTPHeaderRule("X-PowerDNS", "^[a]{5}$"), SpoofAction("2.3.4.5"))
addAction(HTTPPathRule("/PowerDNS"), SpoofAction("3.4.5.6"))
addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9"))
+ addAction("http-status-action.doh3.tests.powerdns.com.", HTTPStatusAction(200, "Plaintext answer", "text/plain"))
+ addAction("http-status-action-redirect.doh3.tests.powerdns.com.", HTTPStatusAction(307, "https://doh.powerdns.org"))
addAction("no-backend.doq.tests.powerdns.com.", PoolAction('this-pool-has-no-backend'))
function dohHandler(dq)
end
end
if foundct then
- return DNSAction.Spoof, "10.11.12.13"
+ dq:setHTTPResponse(200, 'It works!', 'text/plain')
+ dq.dh:setQR(true)
+ return DNSAction.HeaderModify
end
end
return DNSAction.None
self.checkQueryNoEDNS(expectedQuery, receivedQuery)
self.assertEqual(response, receivedResponse)
+ def testHTTPStatusAction200(self):
+ """
+ DOH3: HTTPStatusAction 200 OK
+ """
+ name = 'http-status-action.doh3.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+
+ (receivedResponse, receivedHeaders) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True, rawResponse=True)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(receivedResponse, b'Plaintext answer')
+ self.assertIn(b':status', receivedHeaders)
+ self.assertEqual(receivedHeaders[b':status'], b'200')
+ self.assertIn(b'content-type', receivedHeaders)
+ self.assertEqual(receivedHeaders[b'content-type'], b'text/plain')
+
+ def testHTTPStatusAction307(self):
+ """
+ DOH3: HTTPStatusAction 307
+ """
+ name = 'http-status-action-redirect.doh3.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+
+ (receivedResponse, receivedHeaders) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True, rawResponse=True)
+ self.assertTrue(receivedResponse)
+ self.assertIn(b':status', receivedHeaders)
+ self.assertEqual(receivedHeaders[b':status'], b'307')
+ self.assertIn(b'location', receivedHeaders)
+ self.assertEqual(receivedHeaders[b'location'], b'https://doh.powerdns.org')
+
def testHTTPLuaBindings(self):
"""
DOH3: Lua HTTP bindings
query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
query.id = 0
- (_, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True)
+ (receivedResponse, receivedHeaders) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, caFile=self._caCert, useQueue=False, post=True, rawResponse=True)
self.assertTrue(receivedResponse)
+ self.assertEqual(receivedResponse, b'It works!')
+ self.assertIn(b':status', receivedHeaders)
+ self.assertEqual(receivedHeaders[b':status'], b'200')
+ self.assertIn(b'content-type', receivedHeaders)
+ self.assertEqual(receivedHeaders[b'content-type'], b'text/plain')
class TestDOH3ACL(QUICACLTests, DNSDistTest):
_serverKey = 'server.key'