]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix ECS addition when the OPT record is not the last one
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 19 Jul 2019 14:33:10 +0000 (16:33 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 7 Oct 2019 15:46:18 +0000 (17:46 +0200)
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist.cc
pdns/dnsparser.cc
pdns/dnsparser.hh
pdns/ednsoptions.cc
pdns/ednsoptions.hh
pdns/test-dnsdist_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_EdnsClientSubnet.py

index 67fc5733728184d442c8793e3a91e35a5da3ced2..5b97f4d1e88b40c521634ab4c1af2377bbac04ac 100644 (file)
@@ -117,6 +117,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
       pr.xfrBlob(blob);
       pw.xfrBlob(blob);
     } else {
+
       pr.skip(ah.d_clen);
     }
   }
@@ -125,6 +126,192 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
   return 0;
 }
 
+static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+  for (auto it = options.begin(); it != options.end(); ) {
+    if (it->first == EDNSOptionCode::ECS) {
+      ecsAdded = false;
+
+      if (!overrideExisting) {
+        return false;
+      }
+
+      it = options.erase(it);
+    }
+    else {
+      ++it;
+    }
+  }
+
+  options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
+  return true;
+}
+
+static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+{
+  assert(initialPacket.size() >= sizeof(dnsheader));
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+  ecsAdded = false;
+  ednsAdded = true;
+
+  if (ntohs(dh->qdcount) == 0) {
+    return false;
+  }
+
+  if (ntohs(dh->arcount) == 0) {
+    throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
+  }
+
+  PacketReader pr(initialPacket);
+
+  size_t idx = 0;
+  DNSName rrname;
+  uint16_t qdcount = ntohs(dh->qdcount);
+  uint16_t ancount = ntohs(dh->ancount);
+  uint16_t nscount = ntohs(dh->nscount);
+  uint16_t arcount = ntohs(dh->arcount);
+  uint16_t rrtype;
+  uint16_t rrclass;
+  string blob;
+  struct dnsrecordheader ah;
+
+  rrname = pr.getName();
+  rrtype = pr.get16BitInt();
+  rrclass = pr.get16BitInt();
+
+  DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+  pw.getHeader()->id=dh->id;
+  pw.getHeader()->qr=dh->qr;
+  pw.getHeader()->aa=dh->aa;
+  pw.getHeader()->tc=dh->tc;
+  pw.getHeader()->rd=dh->rd;
+  pw.getHeader()->ra=dh->ra;
+  pw.getHeader()->ad=dh->ad;
+  pw.getHeader()->cd=dh->cd;
+  pw.getHeader()->rcode=dh->rcode;
+
+  /* consume remaining qd if any */
+  if (qdcount > 1) {
+    for(idx = 1; idx < qdcount; idx++) {
+      rrname = pr.getName();
+      rrtype = pr.get16BitInt();
+      rrclass = pr.get16BitInt();
+      (void) rrtype;
+      (void) rrclass;
+    }
+  }
+
+  /* copy AN and NS */
+  for (idx = 0; idx < ancount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  for (idx = 0; idx < nscount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+    pr.xfrBlob(blob);
+    pw.xfrBlob(blob);
+  }
+
+  /* consume AR, looking for OPT */
+  for (idx = 0; idx < arcount; idx++) {
+    rrname = pr.getName();
+    pr.getDnsrecordheader(ah);
+
+    if (ah.d_type != QType::OPT) {
+      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+      pr.xfrBlob(blob);
+      pw.xfrBlob(blob);
+    } else {
+
+      ednsAdded = false;
+      pr.xfrBlob(blob);
+
+      std::vector<std::pair<uint16_t, std::string>> options;
+      getEDNSOptionsFromContent(blob, options);
+
+      EDNS0Record edns0;
+      static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+      memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
+
+      /* addOrReplaceECSOption will set it to false if there is already an existing option */
+      ecsAdded = true;
+      addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
+      pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
+    }
+  }
+
+  if (ednsAdded) {
+    pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
+    ecsAdded = true;
+  }
+
+  pw.commit();
+
+  return true;
+}
+
+static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr<std::map<uint16_t, EDNSOptionView> >& options)
+{
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+  if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0)
+ {
+    return false;
+  }
+
+  if (ntohs(dh->arcount) == 0) {
+    throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
+  }
+
+  try {
+    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
+    DNSPacketMangler dpm(const_cast<char*>(packet), len);
+    uint64_t n;
+    for(n=0; n < ntohs(dh->qdcount) ; ++n) {
+      dpm.skipDomainName();
+      /* type and class */
+      dpm.skipBytes(4);
+    }
+
+    for(n=0; n < numrecords; ++n) {
+      dpm.skipDomainName();
+
+      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
+      uint16_t dnstype = dpm.get16BitInt();
+      dpm.get16BitInt();
+      dpm.skipBytes(4); /* TTL */
+
+      if(section == 3 && dnstype == QType::OPT) {
+        uint32_t offset = dpm.getOffset();
+        if (offset >= len) {
+          return false;
+        }
+        /* if we survive this call, we can parse it safely */
+        dpm.skipRData();
+        return getEDNSOptions(packet + offset, len - offset, *options) == 0;
+      }
+      else {
+        dpm.skipRData();
+      }
+    }
+  }
+  catch(...)
+  {
+    return false;
+  }
+
+  return true;
+}
+
 int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
 {
   assert(optStart != NULL);
@@ -315,6 +502,11 @@ bool parseEDNSOptions(DNSQuestion& dq)
   }
 
   dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
+
+  if (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1) {
+    return slowParseEDNSOptions(reinterpret_cast<const char*>(dq.dh), dq.len, dq.ednsOptions);
+  }
+
   const char* packet = reinterpret_cast<const char*>(dq.dh);
 
   size_t remaining = 0;
@@ -329,7 +521,7 @@ bool parseEDNSOptions(DNSQuestion& dq)
   return false;
 }
 
-static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool* const ecsAdded)
+static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded)
 {
   /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
   /* getEDNSOptionsStart has already checked that there is exactly one AR,
@@ -348,12 +540,12 @@ static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uin
 
   memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
   *len += newECSOptionSize;
-  *ecsAdded = true;
+  ecsAdded = true;
 
   return true;
 }
 
-static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool* const ednsAdded, bool preserveTrailingData)
+static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
 {
   /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
   string EDNSRR;
@@ -378,27 +570,50 @@ static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t
   uint16_t arcount = ntohs(dh->arcount);
   arcount++;
   dh->arcount = htons(arcount);
-  *ednsAdded = true;
+  ednsAdded = true;
+  ecsAdded = true;
 
   memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
 
   return true;
 }
 
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
+bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
 {
   assert(packet != nullptr);
   assert(len != nullptr);
   assert(consumed <= (size_t) *len);
-  assert(ednsAdded != nullptr);
-  assert(ecsAdded != nullptr);
+
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+
+  if (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1) {
+    vector<uint8_t> newContent;
+    newContent.reserve(packetSize);
+
+    if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+      ednsAdded = false;
+      ecsAdded = false;
+      return false;
+    }
+
+    if (newContent.size() > packetSize) {
+      ednsAdded = false;
+      ecsAdded = false;
+      return false;
+    }
+
+    memcpy(packet, &newContent.at(0), newContent.size());
+    *len = newContent.size();
+    return true;
+  }
+
   uint16_t optRDPosition = 0;
   size_t remaining = 0;
 
   int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
 
   if (res != 0) {
-    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, preserveTrailingData);
+    return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData);
   }
 
   unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
@@ -422,7 +637,7 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u
   return true;
 }
 
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData)
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
 {
   assert(dq.remote != nullptr);
   string newECSOption;
index 767575723f059ca4320dfb89c88a055fb37fb300..b339269064aeea4cd25bb072911b3e377cc34abb 100644 (file)
@@ -38,8 +38,8 @@ bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const s
 bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
 
-bool handleEDNSClientSubnet(DNSQuestion& dq, bool* ednsAdded, bool* ecsAdded, bool preserveTrailingData);
-bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool* const ednsAdded, bool* const ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData);
+bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData);
+bool handleEDNSClientSubnet(char* packet, size_t packetSize, unsigned int consumed, uint16_t* len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData);
 
 bool parseEDNSOptions(DNSQuestion& dq);
 
index 7e16700c80f0f910fe32339c5fa0b03cfe7ef6d6..85a6108d694ae2d3e8d812f6794b27c90d4c022b 100644 (file)
@@ -183,7 +183,7 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult
       std::string newECSOption;
       generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() :  dq->ecsPrefixLength);
 
-      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, &ednsAdded, &ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) {
+      if (!handleEDNSClientSubnet(const_cast<char*>(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, ednsAdded, ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) {
         return DNSAction::Action::None;
       }
 
index 1c3f54a62f659a74a8e3f9065e5d00f0009a331b..a5978cae5949de23c4dffeaf051d61f50758d016 100644 (file)
@@ -1489,7 +1489,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
         }
       }
 
-      if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
+      if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded, g_preserveTrailingData)) {
         vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
         return ProcessQueryResult::Drop;
       }
index 0fd335edba431d4ddd08b8607655e0b267a5ddc3..f4b5e816f42421a56f75e8ecb83e78220f7393a9 100644 (file)
@@ -575,122 +575,6 @@ string simpleCompress(const string& elabel, const string& root)
   return ret;
 }
 
-
-/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
- *  If you survive that, feel free to read from the pointer */
-class DNSPacketMangler
-{
-public:
-  explicit DNSPacketMangler(std::string& packet)
-    : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
-  {}
-  DNSPacketMangler(char* packet, size_t length)
-    : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
-  {}
-  
-  /*! Advances past a wire-format domain name
-   * The name is not checked for adherence to length restrictions.
-   * Compression pointers are not followed.
-   */
-  void skipDomainName()
-  {
-    uint8_t len; 
-    while((len=get8BitInt())) { 
-      if(len >= 0xc0) { // extended label
-        get8BitInt();
-        return;
-      }
-      skipBytes(len);
-    }
-  }
-
-  void skipBytes(uint16_t bytes)
-  {
-    moveOffset(bytes);
-  }
-  void rewindBytes(uint16_t by)
-  {
-    rewindOffset(by);
-  }
-  uint32_t get32BitInt()
-  {
-    const char* p = d_packet + d_offset;
-    moveOffset(4);
-    uint32_t ret;
-    memcpy(&ret, (void*)p, sizeof(ret));
-    return ntohl(ret);
-  }
-  uint16_t get16BitInt()
-  {
-    const char* p = d_packet + d_offset;
-    moveOffset(2);
-    uint16_t ret;
-    memcpy(&ret, (void*)p, sizeof(ret));
-    return ntohs(ret);
-  }
-  
-  uint8_t get8BitInt()
-  {
-    const char* p = d_packet + d_offset;
-    moveOffset(1);
-    return *p;
-  }
-  
-  void skipRData()
-  {
-    int toskip = get16BitInt();
-    moveOffset(toskip);
-  }
-
-  void decreaseAndSkip32BitInt(uint32_t decrease)
-  {
-    const char *p = d_packet + d_offset;
-    moveOffset(4);
-
-    uint32_t tmp;
-    memcpy(&tmp, (void*) p, sizeof(tmp));
-    tmp = ntohl(tmp);
-    tmp-=decrease;
-    tmp = htonl(tmp);
-    memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
-  }
-  void setAndSkip32BitInt(uint32_t value)
-  {
-    moveOffset(4);
-
-    value = htonl(value);
-    memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value));
-  }
-  uint32_t getOffset() const
-  {
-    return d_offset;
-  }
-private:
-  void moveOffset(uint16_t by)
-  {
-    d_notyouroffset += by;
-    if(d_notyouroffset > d_length)
-      throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " 
-      + std::to_string(d_length) );
-  }
-  void rewindOffset(uint16_t by)
-  {
-    if(d_notyouroffset < by)
-      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
-                              + std::to_string(by));
-    d_notyouroffset -= by;
-    if(d_notyouroffset < 12)
-      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
-                              + std::to_string(12));
-  }
-  char* d_packet;
-  size_t d_length;
-  
-  uint32_t d_notyouroffset;  // only 'moveOffset' can touch this
-  const uint32_t&  d_offset; // look.. but don't touch
-  
-};
-
 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
 void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
 {
index 33228bdac2344dd4a3f9b509449c3b7002c8ab15..3bb22f3b7ef631f84aaa54a722eee2228f3740b0 100644 (file)
@@ -408,4 +408,123 @@ std::shared_ptr<T> getRR(const DNSRecord& dr)
   return std::dynamic_pointer_cast<T>(dr.d_content);
 }
 
+/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
+ *  If you survive that, feel free to read from the pointer */
+class DNSPacketMangler
+{
+public:
+  explicit DNSPacketMangler(std::string& packet)
+    : d_packet((char*) packet.c_str()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
+  {}
+  DNSPacketMangler(char* packet, size_t length)
+    : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
+  {}
+
+  /*! Advances past a wire-format domain name
+   * The name is not checked for adherence to length restrictions.
+   * Compression pointers are not followed.
+   */
+  void skipDomainName()
+  {
+    uint8_t len;
+    while((len=get8BitInt())) {
+      if(len >= 0xc0) { // extended label
+        get8BitInt();
+        return;
+      }
+      skipBytes(len);
+    }
+  }
+
+  void skipBytes(uint16_t bytes)
+  {
+    moveOffset(bytes);
+  }
+  void rewindBytes(uint16_t by)
+  {
+    rewindOffset(by);
+  }
+  uint32_t get32BitInt()
+  {
+    const char* p = d_packet + d_offset;
+    moveOffset(4);
+    uint32_t ret;
+    memcpy(&ret, (void*)p, sizeof(ret));
+    return ntohl(ret);
+  }
+  uint16_t get16BitInt()
+  {
+    const char* p = d_packet + d_offset;
+    moveOffset(2);
+    uint16_t ret;
+    memcpy(&ret, (void*)p, sizeof(ret));
+    return ntohs(ret);
+  }
+
+  uint8_t get8BitInt()
+  {
+    const char* p = d_packet + d_offset;
+    moveOffset(1);
+    return *p;
+  }
+
+  void skipRData()
+  {
+    int toskip = get16BitInt();
+    moveOffset(toskip);
+  }
+
+  void decreaseAndSkip32BitInt(uint32_t decrease)
+  {
+    const char *p = d_packet + d_offset;
+    moveOffset(4);
+
+    uint32_t tmp;
+    memcpy(&tmp, (void*) p, sizeof(tmp));
+    tmp = ntohl(tmp);
+    tmp-=decrease;
+    tmp = htonl(tmp);
+    memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
+  }
+
+  void setAndSkip32BitInt(uint32_t value)
+  {
+    moveOffset(4);
+
+    value = htonl(value);
+    memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value));
+  }
+
+  uint32_t getOffset() const
+  {
+    return d_offset;
+  }
+
+private:
+  void moveOffset(uint16_t by)
+  {
+    d_notyouroffset += by;
+    if(d_notyouroffset > d_length)
+      throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
+      + std::to_string(d_length) );
+  }
+
+  void rewindOffset(uint16_t by)
+  {
+    if(d_notyouroffset < by)
+      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
+                              + std::to_string(by));
+    d_notyouroffset -= by;
+    if(d_notyouroffset < 12)
+      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
+                              + std::to_string(12));
+  }
+
+  char* d_packet;
+  size_t d_length;
+
+  uint32_t d_notyouroffset;  // only 'moveOffset' can touch this
+  const uint32_t&  d_offset; // look.. but don't touch
+};
+
 #endif
index 6b4ec1098ff2c8ac7d0607992abde92616177770..d20755be732723363bcbd1dfe4f61cabe08f03ef 100644 (file)
@@ -105,6 +105,29 @@ int getEDNSOptions(const char* optRR, const size_t len, EDNSOptionViewMap& optio
   return 0;
 }
 
+bool getEDNSOptionsFromContent(const std::string& content, std::vector<std::pair<uint16_t, std::string>>& options)
+{
+  size_t pos = 0;
+  uint16_t code, len;
+  const size_t contentLength = content.size();
+
+  while (pos < contentLength && (contentLength - pos) >= (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
+    code = (static_cast<unsigned char>(content.at(pos)) * 256) + static_cast<unsigned char>(content.at(pos+1));
+    pos += EDNS_OPTION_CODE_SIZE;
+    len = (static_cast<unsigned char>(content.at(pos)) * 256) + static_cast<unsigned char>(content.at(pos+1));
+    pos += EDNS_OPTION_LENGTH_SIZE;
+
+    if (pos > contentLength || len > (contentLength - pos)) {
+      return false;
+    }
+
+    options.emplace_back(code, std::string(&content.at(pos), len));
+    pos += len;
+  }
+
+  return true;
+}
+
 void generateEDNSOption(uint16_t optionCode, const std::string& payload, std::string& res)
 {
   const uint16_t ednsOptionCode = htons(optionCode);
index 019ac9bb9b01d343d7f8fdf62b00dd7ac19f2f42..4c8a330cebc311a879d8eacef6761f476755d722 100644 (file)
@@ -47,6 +47,8 @@ typedef std::map<uint16_t, EDNSOptionView> EDNSOptionViewMap;
 
 /* extract all EDNS0 options from a pointer on the beginning rdLen of the OPT RR */
 int getEDNSOptions(const char* optRR, size_t len, EDNSOptionViewMap& options);
+/* extract all EDNS0 options from the content (so after rdLen) of the OPT RR */
+bool getEDNSOptionsFromContent(const std::string& content, std::vector<std::pair<uint16_t, std::string>>& options);
 
 void generateEDNSOption(uint16_t optionCode, const std::string& payload, std::string& res);
 
index 14d463db31b4928ec5f71cd9c5004d6e7a3185e2..dfa035c4dea19d5ceb094553db3a6fbd4ced7436 100644 (file)
@@ -42,7 +42,7 @@ BOOST_AUTO_TEST_SUITE(test_dnsdist_cc)
 static const uint16_t ECSSourcePrefixV4 = 24;
 static const uint16_t ECSSourcePrefixV6 = 56;
 
-static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false)
+static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false, uint16_t additionals=0)
 {
   MOADNSParser mdp(true, packet, packetSize);
 
@@ -51,7 +51,7 @@ static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=t
   BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U);
   BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U);
   BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0U);
-  uint16_t expectedARCount = 0 + (hasEdns ? 1 : 0) + (hasXPF ? 1 : 0);
+  uint16_t expectedARCount = additionals + (hasEdns ? 1U : 0U) + (hasXPF ? 1U : 0U);
   BOOST_CHECK_EQUAL(mdp.d_header.arcount, expectedARCount);
 }
 
@@ -231,10 +231,10 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
   BOOST_CHECK(static_cast<size_t>(len) > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
-  BOOST_CHECK_EQUAL(ecsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
   validateQuery(packet, len);
   validateECS(packet, len, remote);
   vector<uint8_t> queryWithEDNS;
@@ -250,7 +250,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -273,11 +273,11 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
     packet[len + idx] = 'A';
   }
   len += trailingDataSize;
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
   BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size());
   BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
   BOOST_CHECK_EQUAL(ednsAdded, true);
-  BOOST_CHECK_EQUAL(ecsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
   validateQuery(packet, len);
 
   /* packet with trailing data (preserving trailing data) */
@@ -296,14 +296,14 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS)
     packet[len + idx] = 'A';
   }
   len += trailingDataSize;
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, true));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, true));
   BOOST_REQUIRE_EQUAL(static_cast<size_t>(len), queryWithEDNS.size() + trailingDataSize);
   BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0);
   for (size_t idx = 0; idx < trailingDataSize; idx++) {
     BOOST_CHECK_EQUAL(packet[queryWithEDNS.size() + idx], 'A');
   }
   BOOST_CHECK_EQUAL(ednsAdded, true);
-  BOOST_CHECK_EQUAL(ecsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
   validateQuery(packet, len);
 }
 
@@ -335,10 +335,10 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK(!parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
   BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, true);
-  BOOST_CHECK_EQUAL(ecsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
   validateQuery(packet, dq.len);
   validateECS(packet, dq.len, remote);
 
@@ -352,7 +352,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed)
   BOOST_CHECK(qclass == QClass::IN);
   DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -384,7 +384,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, true);
@@ -400,7 +400,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, false, newECSOption, false));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -436,7 +436,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK(parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
   BOOST_CHECK_GT(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, true);
@@ -453,7 +453,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) {
   BOOST_CHECK(qclass == QClass::IN);
   DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast<dnsheader*>(query.data()), query.size(), query.size(), false, nullptr);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(dq2, &ednsAdded, &ecsAdded, false));
+  BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq2.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -491,7 +491,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -536,7 +536,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) {
   BOOST_CHECK(parseEDNSOptions(dq));
 
   /* And now we add our own ECS */
-  BOOST_CHECK(handleEDNSClientSubnet(dq, &ednsAdded, &ecsAdded, false));
+  BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false));
   BOOST_CHECK_EQUAL(static_cast<size_t>(dq.len), query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -575,7 +575,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
   BOOST_CHECK((size_t) len < query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -614,7 +614,7 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
   BOOST_CHECK((size_t) len > query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
@@ -630,13 +630,227 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) {
   BOOST_CHECK_EQUAL(qname, name);
   BOOST_CHECK(qtype == QType::A);
 
-  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, &ednsAdded, &ecsAdded, true, newECSOption, false));
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
   BOOST_CHECK_EQUAL((size_t) len, query.size());
   BOOST_CHECK_EQUAL(ednsAdded, false);
   BOOST_CHECK_EQUAL(ecsAdded, false);
   validateQuery(reinterpret_cast<char*>(query.data()), len);
 }
 
+BOOST_AUTO_TEST_CASE(replaceECSFollowedByTSIG) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com.");
+  ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, 8);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
+  pw.addOpt(512, 0, 0, opts);
+  pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+  pw.commit();
+  uint16_t len = query.size();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, len, true, false, 1);
+  validateECS(packet, len, remote);
+
+  /* not large enough packet */
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 1);
+}
+
+BOOST_AUTO_TEST_CASE(replaceECSBetweenTwoRecords) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com.");
+  ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  EDNSSubnetOpts ecsOpts;
+  ecsOpts.source = Netmask(origRemote, 8);
+  string origECSOption = makeEDNSSubnetOptsString(ecsOpts);
+  DNSPacketWriter::optvect_t opts;
+  opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption));
+  pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false);
+  pw.xfr32BitInt(0x01020304);
+  pw.addOpt(512, 0, 0, opts);
+  pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+  pw.commit();
+  uint16_t len = query.size();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(packet, len, true, false, 2);
+  validateECS(packet, len, remote);
+
+  /* not large enough packet */
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 2);
+}
+
+BOOST_AUTO_TEST_CASE(insertECSInEDNSBetweenTwoRecords) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com.");
+  ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  pw.startRecord(DNSName("additional"), QType::A, 0, QClass::IN, DNSResourceRecord::ADDITIONAL, false);
+  pw.xfr32BitInt(0x01020304);
+  pw.addOpt(512, 0, 0);
+  pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+  pw.commit();
+  uint16_t len = query.size();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
+  validateQuery(packet, len, true, false, 2);
+  validateECS(packet, len, remote);
+
+  /* not large enough packet */
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), len, true, false, 2);
+}
+
+BOOST_AUTO_TEST_CASE(insertECSAfterTSIG) {
+  bool ednsAdded = false;
+  bool ecsAdded = false;
+  ComboAddress remote("192.168.1.25");
+  DNSName name("www.powerdns.com.");
+  ComboAddress origRemote("127.0.0.1");
+  string newECSOption;
+  generateECSOption(remote, newECSOption, remote.sin4.sin_family == AF_INET ? ECSSourcePrefixV4 : ECSSourcePrefixV6);
+
+  vector<uint8_t> query;
+  DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0);
+  pw.getHeader()->rd = 1;
+  pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false);
+  pw.commit();
+  uint16_t len = query.size();
+
+  /* large enough packet */
+  char packet[1500];
+  memcpy(packet, query.data(), query.size());
+
+  unsigned int consumed = 0;
+  uint16_t qtype;
+  DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK((size_t) len > query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, true);
+  BOOST_CHECK_EQUAL(ecsAdded, true);
+  /* the MOADNSParser does not allow anything except XPF after a TSIG */
+  BOOST_CHECK_THROW(validateQuery(packet, len, true, false, 1), MOADNSException);
+  validateECS(packet, len, remote);
+
+  /* not large enough packet */
+  ednsAdded = false;
+  ecsAdded = false;
+  consumed = 0;
+  len = query.size();
+  qname = DNSName(reinterpret_cast<char*>(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed);
+  BOOST_CHECK_EQUAL(qname, name);
+  BOOST_CHECK(qtype == QType::A);
+
+  BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast<char*>(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false));
+  BOOST_CHECK_EQUAL((size_t) len, query.size());
+  BOOST_CHECK_EQUAL(ednsAdded, false);
+  BOOST_CHECK_EQUAL(ecsAdded, false);
+  validateQuery(reinterpret_cast<char*>(query.data()), len, true, false);
+}
+
 BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) {
   DNSName name("www.powerdns.com.");
 
index f5e0b92ee674c1ea8986b7c7658556b3714abd76..d7c1b3b3633afa58af1d4cb057c02179a2d4b603 100644 (file)
@@ -544,6 +544,9 @@ class DNSDistTest(unittest.TestCase):
         if withCookies:
             for option in received.options:
                 self.assertEquals(option.otype, 10)
+        else:
+            for option in received.options:
+                self.assertNotEquals(option.otype, 10)
 
     def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
         self.assertEquals(expected, received)
index 241783fba29efbed8ebf1f6e0645a026bd261765..b01e6bc38b47f142b566f10b6661e743716536a7 100644 (file)
@@ -204,6 +204,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
         response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse])
         expectedResponse = dns.message.make_response(query)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -242,6 +243,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
         response.use_edns(edns=True, payload=4096, options=[ecsoResponse, ecoResponse])
         expectedResponse = dns.message.make_response(query, our_payload=4096)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -280,6 +282,7 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         ecsoResponse = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24, scope=24)
         response.use_edns(edns=True, payload=4096, options=[ecoResponse, ecsoResponse, ecoResponse])
         expectedResponse = dns.message.make_response(query, our_payload=4096)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse, ecoResponse])
         rrset = dns.rrset.from_text(name,
                                     3600,
                                     dns.rdataclass.IN,
@@ -482,6 +485,97 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
             self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
             self.checkResponseEDNSWithECS(response, receivedResponse)
 
+    def testWithECSFollowedByAnother(self):
+        """
+        ECS: Existing EDNS with ECS, followed by another record
+
+        Send a query with EDNS and an existing ECS value.
+        The OPT record is not the last one in the query
+        and is followed by another one.
+        Check that the query received by the responder
+        has a valid ECS value and that the response
+        received from dnsdist contains an EDNS pseudo-RR.
+        """
+        name = 'withecs-followedbyanother.ecs.tests.powerdns.com.'
+        ecso = clientsubnetoption.ClientSubnetOption('192.0.2.1', 24)
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+        rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,ecso,eco])
+        # I would have loved to use a TSIG here but I can't find how to make dnspython ignore
+        # it while parsing the message in the receiver :-/
+        query.additional.append(rrset)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,eco,rewrittenEcso])
+        expectedQuery.additional.append(rrset)
+
+        response = dns.message.make_response(expectedQuery)
+        response.use_edns(edns=True, payload=4096, options=[eco, ecso, eco])
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[eco, ecso, eco])
+        response.answer.append(rrset)
+        response.additional.append(rrset)
+        expectedResponse.answer.append(rrset)
+        expectedResponse.additional.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 2)
+            self.checkResponseEDNSWithECS(expectedResponse, receivedResponse, 2)
+
+    def testWithEDNSNoECSFollowedByAnother(self):
+        """
+        ECS: Existing EDNS without ECS, followed by another record
+
+        Send a query with EDNS but no ECS value.
+        The OPT record is not the last one in the query
+        and is followed by another one.
+        Check that the query received by the responder
+        has a valid ECS value and that the response
+        received from dnsdist contains an EDNS pseudo-RR.
+        """
+        name = 'withedns-no-ecs-followedbyanother.ecs.tests.powerdns.com.'
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
+        rewrittenEcso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
+        # I would have loved to use a TSIG here but I can't find how to make dnspython ignore
+        # it while parsing the message in the receiver :-/
+        query.additional.append(rrset)
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco,rewrittenEcso])
+        expectedQuery.additional.append(rrset)
+
+        response = dns.message.make_response(expectedQuery)
+        response.use_edns(edns=True, payload=4096, options=[eco, rewrittenEcso, eco])
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.use_edns(edns=True, payload=4096, options=[eco, eco])
+        response.answer.append(rrset)
+        response.additional.append(rrset)
+        expectedResponse.answer.append(rrset)
+        expectedResponse.additional.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, 2)
+
 class TestECSDisabledByRuleOrLua(DNSDistTest):
     """
     dnsdist is configured to add the EDNS0 Client Subnet