return 0;
}
-static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
{
for (auto it = options.begin(); it != options.end(); ) {
- if (it->first == EDNSOptionCode::ECS) {
- ecsAdded = false;
+ if (it->first == optionCode) {
+ optionAdded = false;
if (!overrideExisting) {
return false;
}
}
- options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
+ options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
return true;
}
-static bool slowRewriteQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
- ecsAdded = false;
+ optionAdded = false;
ednsAdded = true;
if (ntohs(dh->qdcount) == 0) {
}
if (ntohs(dh->ancount) == 0 && ntohs(dh->nscount) == 0 && ntohs(dh->arcount) == 0) {
- throw std::runtime_error("slowRewriteQueryWithRecords() should not be called for queries that have no records");
+ throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + " should not be called for queries that have no records");
}
PacketReader pr(pdns_string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
- /* addOrReplaceECSOption will set it to false if there is already an existing option */
- ecsAdded = true;
- addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
+ /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
+ optionAdded = true;
+ addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, newOptionContent);
pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
}
}
if (ednsAdded) {
- pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
- ecsAdded = true;
+ pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
+ optionAdded = true;
}
pw.commit();
PacketBuffer newContent;
newContent.reserve(packet.size());
- if (!slowRewriteQueryWithRecords(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+ if (!slowRewriteEDNSOptionInQueryWithRecords(packet, newContent, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, newECSOption)) {
ednsAdded = false;
ecsAdded = false;
return false;
extern uint16_t g_PayloadSizeSelfGenAnswers;
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent);
+bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent);
int locateEDNSOptRR(const PacketBuffer & packet, uint16_t * optStart, size_t * optLen, bool * last);
bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK);
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength);
public:
// this action does not stop the processing
SetMacAddrAction(uint16_t code) : d_code(code)
- {}
- DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
- if (dq->getHeader()->arcount) {
- return Action::None;
- }
+ }
+ DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+ {
std::string mac = getMACAddress(*dq->remote);
if (mac.empty()) {
return Action::None;
std::string optRData;
generateEDNSOption(d_code, mac, optRData);
+ if (dq->getHeader()->arcount) {
+ bool ednsAdded = false;
+ bool optionAdded = false;
+ PacketBuffer newContent;
+ newContent.reserve(dq->getData().size());
+
+ if (!slowRewriteEDNSOptionInQueryWithRecords(dq->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) {
+ return Action::None;
+ }
+
+ if (newContent.size() > dq->getMaximumSize()) {
+ return Action::None;
+ }
+
+ dq->getMutableData() = std::move(newContent);
+ if (!dq->ednsAdded && ednsAdded) {
+ dq->ednsAdded = true;
+ }
+
+ return Action::None;
+ }
+
auto& data = dq->getMutableData();
if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
dq->getHeader()->arcount = htons(1);
}
std::string toString() const override
{
- return "add EDNS MAC (code="+std::to_string(d_code)+")";
+ return "add EDNS MAC (code=" + std::to_string(d_code) + ")";
}
private:
uint16_t d_code{3};
public:
// this action does not stop the processing
SetEDNSOptionAction(uint16_t code, const std::string& data) : d_code(code), d_data(data)
- {}
+ {
+ }
DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
+ std::string optRData;
+ generateEDNSOption(d_code, d_data, optRData);
+
if (dq->getHeader()->arcount) {
+ bool ednsAdded = false;
+ bool optionAdded = false;
+ PacketBuffer newContent;
+ newContent.reserve(dq->getData().size());
+
+ if (!slowRewriteEDNSOptionInQueryWithRecords(dq->getData(), newContent, ednsAdded, d_code, optionAdded, true, optRData)) {
+ return Action::None;
+ }
+
+ if (newContent.size() > dq->getMaximumSize()) {
+ return Action::None;
+ }
+
+ dq->getMutableData() = std::move(newContent);
+ if (!dq->ednsAdded && ednsAdded) {
+ dq->ednsAdded = true;
+ }
+
return Action::None;
}
- std::string optRData;
- generateEDNSOption(d_code, d_data, optRData);
-
auto& data = dq->getMutableData();
if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
dq->getHeader()->arcount = htons(1);
std::string toString() const override
{
- return "add EDNS Option (code="+std::to_string(d_code)+")";
+ return "add EDNS Option (code=" + std::to_string(d_code) + ")";
}
private:
.. versionadded:: 1.7.0
- Add arbitrary EDNS option and data to the query.
+ Add arbitrary EDNS option and data to the query. Any existing EDNS content with the same option code will be overwritten.
Subsequent rules are processed after this action.
:param int option: The EDNS option number
class TestAdvancedSetEDNSOptionAction(DNSDistTest):
_config_template = """
- addAction("setednsoption.advanced.tests.powerdns.com.", SetEDNSOptionAction(10, "deadbeefdeadc0de"))
+ addAction(AllRule(), SetEDNSOptionAction(10, "deadbeefdeadc0de"))
newServer{address="127.0.0.1:%s"}
"""
self.assertEqual(expectedQuery, receivedQuery)
self.checkResponseNoEDNS(response, receivedResponse)
self.checkQueryEDNS(expectedQuery, receivedQuery)
+
+ def testAdvancedSetEDNSOptionOverwrite(self):
+ """
+ Advanced: Set EDNS Option overwrites an existing option
+ """
+ name = 'setednsoption-overwrite.advanced.tests.powerdns.com.'
+ initialECO = cookiesoption.CookiesOption(b'aaaaaaaa', b'bbbbbbbb')
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ overWrittenECO = cookiesoption.CookiesOption(b'deadbeef', b'deadc0de')
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=512, options=[overWrittenECO])
+
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkResponseNoEDNS(response, receivedResponse)
+ self.checkQueryEDNS(expectedQuery, receivedQuery)