]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist-ecs.cc
dnsdist: Add SetNegativeAndSOAAction() and its Lua binding
[thirdparty/pdns.git] / pdns / dnsdist-ecs.cc
index 56422e48a8a992ca5cceb60be62598c072924144..7cda963d4a1e27d56480f9512e0a1e70d8bd4eec 100644 (file)
@@ -117,6 +117,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
       pr.xfrBlob(blob);
       pw.xfrBlob(blob);
     } else {
+
       pr.skip(ah.d_clen);
     }
   }
@@ -125,6 +126,192 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
   return 0;
 }
 
+static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+  for (auto it = options.begin(); it != options.end(); ) {
+    if (it->first == EDNSOptionCode::ECS) {
+      ecsAdded = false;
+
+      if (!overrideExisting) {
+        return false;
+      }
+
+      it = options.erase(it);
+    }
+    else {
+      ++it;
+    }
+  }
+
+  options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
+  return true;
+}
+
+static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+  assert(initialPacket.size() >= sizeof(dnsheader));
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+  ecsAdded = false;
+  ednsAdded = true;
+
+  if (ntohs(dh->qdcount) == 0) {
+    return false;
+  }
+
+  if (ntohs(dh->arcount) == 0) {
+    throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
+  }
+
+  PacketReader pr(initialPacket);
+
+  size_t idx = 0;
+  DNSName rrname;
+  uint16_t qdcount = ntohs(dh->qdcount);
+  uint16_t ancount = ntohs(dh->ancount);
+  uint16_t nscount = ntohs(dh->nscount);
+  uint16_t arcount = ntohs(dh->arcount);
+  uint16_t rrtype;
+  uint16_t rrclass;
+  string blob;
+  struct dnsrecordheader ah;
+
+  rrname = pr.getName();
+  rrtype = pr.get16BitInt();
+  rrclass = pr.get16BitInt();
+
+  DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+  pw.getHeader()->id=dh->id;
+  pw.getHeader()->qr=dh->qr;
+  pw.getHeader()->aa=dh->aa;
+  pw.getHeader()->tc=dh->tc;
+  pw.getHeader()->rd=dh->rd;
+  pw.getHeader()->ra=dh->ra;
+  pw.getHeader()->ad=dh->ad;
+  pw.getHeader()->cd=dh->cd;
+  pw.getHeader()->rcode=dh->rcode;
+
+  /* consume remaining qd if any */
+  if (qdcount > 1) {
+    for(idx = 1; idx < qdcount; idx++) {
+      rrname = pr.getName();
+      rrtype = pr.get16BitInt();
+      rrclass = pr.get16BitInt();
+      (void) rrtype;
+      (void) rrclass;
+    }
+  }
+
+  /* copy AN and NS */
+  for (idx = 0; idx < ancount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  for (idx = 0; idx < nscount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  /* consume AR, looking for OPT */
+  for (idx = 0; idx < arcount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    if (ah.d_type != QType::OPT) {
+      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+      pr.xfrBlob(blob);
+      pw.xfrBlob(blob);
+    } else {
+
+      ednsAdded = false;
+      pr.xfrBlob(blob);
+
+      std::vector<std::pair<uint16_t, std::string>> options;
+      getEDNSOptionsFromContent(blob, options);
+
+      EDNS0Record edns0;
+      static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+      memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
+
+      /* addOrReplaceECSOption will set it to false if there is already an existing option */
+      ecsAdded = true;
+      addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
+      pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
+    }
+  }
+
+  if (ednsAdded) {
+    pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
+    ecsAdded = true;
+  }
+
+  pw.commit();
+
+  return true;
+}
+
+static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr<std::map<uint16_t, EDNSOptionView> >& options)
+{
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+  if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0)
+ {
+    return false;
+  }
+
+  if (ntohs(dh->arcount) == 0) {
+    throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
+  }
+
+  try {
+    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+    DNSPacketMangler dpm(const_cast<char*>(packet), len);
+    uint64_t n;
+    for(n=0; n < ntohs(dh->qdcount) ; ++n) {
+      dpm.skipDomainName();
+      /* type and class */
+      dpm.skipBytes(4);
+    }
+
+    for(n=0; n < numrecords; ++n) {
+      dpm.skipDomainName();
+
+      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
+      uint16_t dnstype = dpm.get16BitInt();
+      dpm.get16BitInt();
+      dpm.skipBytes(4); /* TTL */
+
+      if(section == 3 && dnstype == QType::OPT) {
+        uint32_t offset = dpm.getOffset();
+        if (offset >= len) {
+          return false;
+        }
+        /* if we survive this call, we can parse it safely */
+        dpm.skipRData();
+        return getEDNSOptions(packet + offset, len - offset, *options) == 0;
+      }
+      else {
+        dpm.skipRData();
+      }
+    }
+  }
+  catch(...)
+  {
+    return false;
+  }
+
+  return true;
+}
+
 int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
 {
   assert(optStart != NULL);
@@ -257,10 +444,10 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
-  dh.d_clen = htons((uint16_t) optRData.length());
+  dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
   res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
-  res.assign((const char *) &name, sizeof name);
-  res.append((const char *) &dh, sizeof dh);
+  res.assign(reinterpret_cast<const char *>(&name), sizeof name);
+  res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
   res.append(optRData.c_str(), optRData.length());
 }
 
@@ -315,6 +502,11 @@ bool parseEDNSOptions(DNSQuestion& dq)
   }
 
   dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
+
+  if (ntohs(dq.dh->ancount) != 0 || ntohs(dq.dh->nscount) != 0 || (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1)) {
+    return slowParseEDNSOptions(reinterpret_cast<const char*>(dq.dh), dq.len, dq.ednsOptions);
+  }
+
   const char* packet = reinterpret_cast<const char*>(dq.dh);
 
   size_t remaining = 0;
@@ -329,7 +521,7 @@ bool parseEDNSOptions(DNSQuestion& dq)
   return false;
 }
 
-static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded)
+static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded)
 {
   /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
   /* getEDNSOptionsStart has already checked that there is exactly one AR,
@@ -348,12 +540,12 @@ static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uin
 
   memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
   *len += newECSOptionSize;
-  *ecsAdded = true;
+  ecsAdded = true;
 
   return true;
 }
 
-static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData)
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
 {
   /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
   string EDNSRR;
@@ -378,27 +570,50 @@ static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t
   uint16_t arcount = ntohs(dh->arcount);
   arcount++;
   dh->arcount = htons(arcount);
-  *ednsAdded = true;
+  ednsAdded = true;
+  ecsAdded = true;
 
   memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
 
   return true;
 }
 
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
 {
   assert(packet != nullptr);
   assert(len != nullptr);
   assert(consumed <= (size_t) *len);
-  assert(ednsAdded != nullptr);
-  assert(ecsAdded != nullptr);
+
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+  if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
+    vector<uint8_t> newContent;
+    newContent.reserve(packetSize);
+
+    if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+      ednsAdded = false;
+      ecsAdded = false;
+      return false;
+    }
+
+    if (newContent.size() > packetSize) {
+      ednsAdded = false;
+      ecsAdded = false;
+      return false;
+    }
+
+    memcpy(packet, &newContent.at(0), newContent.size());
+    *len = newContent.size();
+    return true;
+  }
+
   uint16_t optRDPosition = 0;
   size_t remaining = 0;
 
   int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
 
   if (res != 0) {
-    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData);
+    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData);
   }
 
   unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
@@ -422,7 +637,7 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
   return true;
 }
 
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData)
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
 {
   assert(dq.remote != nullptr);
   string newECSOption;
@@ -464,9 +679,7 @@ static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16
 
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
 {
-  /* we need at least:
-     root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (*optLen < 11) {
+  if (*optLen < optRecordMinimumSize) {
     return EINVAL;
   }
   const unsigned char* end = (const unsigned char*) optStart + *optLen;
@@ -490,23 +703,21 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
 
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
 {
-  /* we need at least:
-   root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (optLen < 11) {
+  if (optLen < optRecordMinimumSize) {
     return false;
   }
   size_t p = optStart + 9;
-  uint16_t rdLen = (0x100*packet.at(p) + packet.at(p+1));
+  uint16_t rdLen = (0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1)));
   p += sizeof(rdLen);
-  if (rdLen > (optLen - 11)) {
+  if (rdLen > (optLen - optRecordMinimumSize)) {
     return false;
   }
 
   size_t rdEnd = p + rdLen;
   while ((p + 4) <= rdEnd) {
-    const uint16_t optionCode = 0x100*packet.at(p) + packet.at(p+1);
+    const uint16_t optionCode = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
     p += sizeof(optionCode);
-    const uint16_t optionLen = 0x100*packet.at(p) + packet.at(p+1);
+    const uint16_t optionLen = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
     p += sizeof(optionLen);
 
     if ((p + optionLen) > rdEnd) {
@@ -628,10 +839,6 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
 
 bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
 {
-  if (dh->arcount != 0) {
-    return false;
-  }
-
   std::string optRecord;
   generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK);
 
@@ -642,7 +849,101 @@ bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uin
   char * optPtr = reinterpret_cast<char*>(dh) + len;
   memcpy(optPtr, optRecord.data(), optRecord.size());
   len += optRecord.size();
-  dh->arcount = htons(1);
+  dh->arcount = htons(ntohs(dh->arcount) + 1);
+
+  return true;
+}
+
+/*
+  This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
+  generating a NXD or NODATA answer with a SOA record in the additional section.
+*/
+bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum)
+{
+  if (ntohs(dq.dh->qdcount) != 1) {
+    return false;
+  }
+
+  assert(dq.consumed == dq.qname->wirelength());
+  size_t queryPartSize = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
+  if (dq.len < queryPartSize) {
+    /* something is already wrong, don't build on flawed foundations */
+    return false;
+  }
+
+  size_t available = dq.size - queryPartSize;
+  uint16_t qtype = htons(QType::SOA);
+  uint16_t qclass = htons(QClass::IN);
+  uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
+  size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
+
+  if (soaSize > available) {
+    /* not enough space left to add the SOA, sorry! */
+    return false;
+  }
+
+  bool hadEDNS = false;
+  bool dnssecOK = false;
+
+  if (g_addEDNSToSelfGeneratedResponses) {
+    uint16_t payloadSize = 0;
+    uint16_t z = 0;
+    hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &payloadSize, &z);
+    if (hadEDNS) {
+      dnssecOK = z & EDNS_HEADER_FLAG_DO;
+    }
+  }
+
+  /* chop off everything after the question */
+  dq.len = queryPartSize;
+  if (nxd) {
+    dq.dh->rcode = RCode::NXDomain;
+  }
+  else {
+    dq.dh->rcode = RCode::NoError;
+  }
+  dq.dh->qr = true;
+  dq.dh->ancount = 0;
+  dq.dh->nscount = 0;
+  dq.dh->arcount = 0;
+
+  rdLength = htons(rdLength);
+  ttl = htonl(ttl);
+  serial = htonl(serial);
+  refresh = htonl(refresh);
+  retry = htonl(retry);
+  expire = htonl(expire);
+  minimum = htonl(minimum);
+
+  std::string soa;
+  soa.reserve(soaSize);
+  soa.append(zone.toDNSString());
+  soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
+  soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
+  soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
+  soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
+  soa.append(mname.toDNSString());
+  soa.append(rname.toDNSString());
+  soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
+  soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
+  soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
+  soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
+  soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
+
+  if (soa.size() != soaSize) {
+    throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
+  }
+
+  memcpy(reinterpret_cast<char*>(dq.dh) + queryPartSize, soa.c_str(), soa.size());
+
+  dq.len += soa.size();
+
+  dq.dh->arcount = htons(1);
+
+  if (g_addEDNSToSelfGeneratedResponses) {
+    /* now we need to add a new OPT record */
+    return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
+  }
 
   return true;
 }
@@ -741,3 +1042,31 @@ bool queryHasEDNS(const DNSQuestion& dq)
 
   return false;
 }
+
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
+{
+  uint16_t optStart;
+  size_t optLen = 0;
+  bool last = false;
+  const char * packet = reinterpret_cast<const char*>(dq.dh);
+  std::string packetStr(packet, dq.len);
+  int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+  if (res != 0) {
+    // no EDNS OPT RR
+    return false;
+  }
+
+  if (optLen < optRecordMinimumSize) {
+    return false;
+  }
+
+  if (optStart < dq.len && packetStr.at(optStart) != 0) {
+    // OPT RR Name != '.'
+    return false;
+  }
+
+  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
+  memcpy(&edns0, packet + optStart + 5, sizeof edns0);
+  return true;
+}