std::string d_value;
};
+class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+ ClearRecordTypesResponseAction() {}
+
+ ClearRecordTypesResponseAction(std::set<QType> qtypes) : d_qtypes(qtypes)
+ {
+ }
+
+ DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+ {
+ if (d_qtypes.size() > 0) {
+ clearDNSPacketRecordTypes(dr->getMutableData(), d_qtypes);
+ }
+ return DNSResponseAction::Action::HeaderModify;
+ }
+
+ std::string toString() const override
+ {
+ return "clear record types";
+ }
+
+private:
+ std::set<QType> d_qtypes{};
+};
+
class ContinueAction : public DNSAction
{
public:
return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max));
});
+ luaCtx.writeFunction("ClearRecordTypesResponseAction", [](boost::variant<int,vector<pair<int, int>>> types) {
+ std::set<QType> qtypes{};
+ if(auto t = boost::get<int>(types)) {
+ qtypes.insert(t);
+ } else {
+ const auto& v = boost::get<vector<pair<int,int>>>(types);
+ for(const auto& tpair: v) {
+ qtypes.insert(tpair.second);
+ }
+ }
+ return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(qtypes));
+ });
+
luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) {
auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode));
auto rca = std::dynamic_pointer_cast<RCodeAction>(ret);
Let these packets go through.
+.. function::ClearRecordTypesResponseAction(types)
+
+ .. versionadded:: 1.8.0
+
+ Removes given type(s) records from the response.
+
+ :param int types: a single type or a list of types to remove
+
.. function:: ContinueAction(action)
.. versionadded:: 1.4.0
}
}
-void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
-{
- size_t finalsize = packet.size();
- clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
- packet.resize(finalsize);
-}
-
-void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set<QType>& qtypes)
+int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::set<QType>& qtypes)
{
- size_t finalsize = packet.size();
- clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
- packet.resize(finalsize);
-}
+ static const std::set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::LUA, QType::ENT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA};
-// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
-void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes)
-{
- if (length < sizeof(dnsheader)) {
- return;
+ if (initialPacket.size() < sizeof(dnsheader)) {
+ return EINVAL;
}
- try
- {
- dnsheader* dh = reinterpret_cast<dnsheader*>(packet);
- uint64_t ancount = ntohs(dh->ancount);
- uint64_t nscount = ntohs(dh->nscount);
- uint64_t arcount = ntohs(dh->arcount);
- DNSPacketMangler dpm(packet, length);
+ try {
+ const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+ if (ntohs(dh->qdcount) == 0)
+ return ENOENT;
+ auto packetView = pdns_string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size());
+
+ PacketReader pr(packetView);
+
+ size_t idx = 0;
+ DNSName rrname;
+ uint16_t qdcount = ntohs(dh->qdcount);
+ uint16_t ancount = ntohs(dh->ancount);
+ uint16_t nscount = ntohs(dh->nscount);
+ uint16_t arcount = ntohs(dh->arcount);
+ uint16_t rrtype;
+ uint16_t rrclass;
+ string blob;
+ struct dnsrecordheader ah;
- for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) {
- dpm.skipDomainName();
- /* type and class */
- dpm.skipBytes(4);
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+
+ GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+ pw.getHeader()->id=dh->id;
+ pw.getHeader()->qr=dh->qr;
+ pw.getHeader()->aa=dh->aa;
+ pw.getHeader()->tc=dh->tc;
+ pw.getHeader()->rd=dh->rd;
+ pw.getHeader()->ra=dh->ra;
+ pw.getHeader()->ad=dh->ad;
+ pw.getHeader()->cd=dh->cd;
+ pw.getHeader()->rcode=dh->rcode;
+
+ /* consume remaining qd if any */
+ if (qdcount > 1) {
+ for(idx = 1; idx < qdcount; idx++) {
+ rrname = pr.getName();
+ rrtype = pr.get16BitInt();
+ rrclass = pr.get16BitInt();
+ (void) rrtype;
+ (void) rrclass;
+ }
+ }
+
+ /* copy AN */
+ for (idx = 0; idx < ancount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+ pr.xfrBlob(blob);
+
+ if (qtypes.find(ah.d_type) == qtypes.end()) {
+ // if this is not a safe type
+ if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+ // "unsafe" types might countain compressed data, so cancel rewrite
+ newContent.clear();
+ return EIO;
+ }
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+ pw.xfrBlob(blob);
+ }
+ }
+
+ /* copy NS */
+ for (idx = 0; idx < nscount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+ pr.xfrBlob(blob);
+
+ if (qtypes.find(ah.d_type) == qtypes.end()) {
+ if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+ // "unsafe" types might countain compressed data, so cancel rewrite
+ newContent.clear();
+ return EIO;
+ }
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+ pw.xfrBlob(blob);
+ }
}
- auto processSection = [&dpm, &qtypes, &length](uint64_t& count) {
- for (uint64_t n=0; n < count; ++n) {
- uint32_t start = dpm.getOffset();
- dpm.skipDomainName();
-
- const auto dnstype = QType{dpm.get16BitInt()};
- /* class, ttl */
- dpm.skipBytes(6);
- dpm.skipRData();
- if (qtypes.find(dnstype) != qtypes.end()) {
- uint32_t rlen = dpm.getOffset() - start;
- dpm.rewindBytes(rlen);
- dpm.shrinkBytes(rlen);
- // update count
- count -= 1;
- length -= rlen;
- n -= 1;
+ /* copy AR */
+ for (idx = 0; idx < arcount; idx++) {
+ rrname = pr.getName();
+ pr.getDnsrecordheader(ah);
+ pr.xfrBlob(blob);
+
+ if (qtypes.find(ah.d_type) == qtypes.end()) {
+ if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+ // "unsafe" types might countain compressed data, so cancel rewrite
+ newContent.clear();
+ return EIO;
}
+ pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+ pw.xfrBlob(blob);
}
- };
- processSection(ancount);
- dh->ancount = htons(ancount);
- processSection(nscount);
- dh->nscount = htons(nscount);
- processSection(arcount);
- dh->arcount = htons(arcount);
+ }
+ pw.commit();
+
}
- catch(...)
+ catch (...)
{
- return;
+ newContent.clear();
+ return EIO;
+ }
+ return 0;
+}
+
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
+{
+ return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer&>(packet), qtypes);
+}
+
+void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set<QType>& qtypes)
+{
+ PacketBuffer newContent;
+
+ auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes);
+ if (!result) {
+ packet = std::move(newContent);
}
}
pwR.xfrIP6(std::string(reinterpret_cast<const char*>(v6.sin6.sin6_addr.s6_addr), 16));
pwR.commit();
- pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY);
- pwR.commit();
-
pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL);
pwR.xfrIP(v4.sin4.sin_addr.s_addr);
pwR.commit();
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
- BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
std::set<QType> toremove{QType::AAAA};
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
- BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
toremove = {QType::A};
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
- BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
packet = generatePacket();
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
- BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
toremove = {QType::A, QType::AAAA};
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
- BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
}
newServer{address="127.0.0.1:%s"}
- addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaResponseAction(luafct))
+ addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaFFIResponseAction(luafct))
+ addResponseAction("clear-records-type.responses.tests.powerdns.com.", ClearRecordTypesResponseAction(DNSQType.AAAA))
"""
def testClearedFFI(self):
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertEqual(expectedResponse, receivedResponse)
+
+ def testCleared(self):
+ """
+ Responses: Removes records of a given type
+ """
+ name = 'clear-records-type.responses.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '192.0.2.1')
+ response.answer.append(rrset)
+ expectedResponse.answer.append(rrset)
+ rrset = dns.rrset.from_text(name,
+ 3660,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '2001:DB8::1', '2001:DB8::2')
+ response.answer.append(rrset)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(expectedResponse, receivedResponse)