]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #7653 from pieterlexis/docker-ignore
authorPieter Lexis <pieterlexis@users.noreply.github.com>
Fri, 12 Apr 2019 10:12:08 +0000 (12:12 +0200)
committerGitHub <noreply@github.com>
Fri, 12 Apr 2019 10:12:08 +0000 (12:12 +0200)
Add dockerignore

166 files changed:
README.md
builder-support/dockerfiles/Dockerfile.rpmbuild
configure.ac
docs/backends/generic-mysql.rst
docs/backends/generic-postgresql.rst
docs/changelog/4.1.rst
docs/dnssec/modes-of-operation.rst
docs/manpages/dnswasher.1.rst
docs/manpages/ixfrdist.yml.5.rst
docs/manpages/pdnsutil.1.rst
docs/secpoll.zone
docs/settings.rst
ext/Makefile.am
ext/ipcrypt/.gitignore [new file with mode: 0644]
ext/ipcrypt/LICENSE [new file with mode: 0644]
ext/ipcrypt/Makefile.am [new file with mode: 0644]
ext/ipcrypt/ipcrypt.c [new file with mode: 0644]
ext/ipcrypt/ipcrypt.h [new file with mode: 0644]
m4/pdns_check_libcrypto.m4
modules/gmysqlbackend/gmysqlbackend.cc
modules/gmysqlbackend/smysql.cc
modules/gmysqlbackend/smysql.hh
modules/lmdbbackend/lmdb-safe.cc
modules/lmdbbackend/lmdb-safe.hh
modules/lmdbbackend/lmdb-typed.hh
modules/lmdbbackend/lmdbbackend.cc
pdns/Makefile.am
pdns/README-dnsdist.md
pdns/backends/gsql/gsqlbackend.cc
pdns/common_startup.cc
pdns/devpollmplexer.cc
pdns/distributor.hh
pdns/dnscrypt.cc
pdns/dnscrypt.hh
pdns/dnsdist-cache.cc
pdns/dnsdist-carbon.cc
pdns/dnsdist-console.cc
pdns/dnsdist-dynblocks.hh
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua-vars.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist-web.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/configure.ac
pdns/dnsdistdist/dnsdist-idstate.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/advanced/axfr.rst
pdns/dnsdistdist/docs/advanced/luaaction.rst
pdns/dnsdistdist/docs/advanced/timedipsetrule.rst
pdns/dnsdistdist/docs/advanced/tuning.rst
pdns/dnsdistdist/docs/changelog.rst
pdns/dnsdistdist/docs/guides/cache.rst
pdns/dnsdistdist/docs/guides/serverselection.rst
pdns/dnsdistdist/docs/reference/comboaddress.rst
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/docs/reference/constants.rst
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/docs/reference/tuning.rst
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsdistdist/docs/statistics.rst
pdns/dnsdistdist/docs/upgrade_guide.rst
pdns/dnsdistdist/ext/ipcrypt/LICENSE [new symlink]
pdns/dnsdistdist/ext/ipcrypt/Makefile.am [new symlink]
pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c [new symlink]
pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h [new symlink]
pdns/dnsdistdist/html/local.js
pdns/dnsdistdist/ipcipher.cc [new symlink]
pdns/dnsdistdist/ipcipher.hh [new symlink]
pdns/dnsdistdist/tcpiohandler.cc
pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc
pdns/dnsdistdist/test-dnsdistrules_cc.cc
pdns/dnsdistdist/test-mplexer.cc [new symlink]
pdns/dnswasher.cc
pdns/dolog.hh
pdns/epollmplexer.cc
pdns/ipcipher.cc [new file with mode: 0644]
pdns/ipcipher.hh [new file with mode: 0644]
pdns/iputils.cc
pdns/iputils.hh
pdns/ixfrdist-web.cc
pdns/ixfrdist-web.hh
pdns/ixfrdist.cc
pdns/ixfrdist.example.yml
pdns/kqueuemplexer.cc
pdns/lua-recursor4.cc
pdns/misc.cc
pdns/mplexer.hh
pdns/mtasker.cc
pdns/mtasker.hh
pdns/packethandler.cc
pdns/pdns_recursor.cc
pdns/pdnsutil.cc
pdns/pollmplexer.cc
pdns/portsmplexer.cc
pdns/rec-snmp.cc
pdns/rec_channel_rec.cc
pdns/recursordist/Makefile.am
pdns/recursordist/README.md
pdns/recursordist/RECURSOR-MIB.txt
pdns/recursordist/configure.ac
pdns/recursordist/docs/changelog/4.1.rst
pdns/recursordist/docs/manpages/rec_control.1.rst
pdns/recursordist/docs/metrics.rst
pdns/recursordist/docs/settings.rst
pdns/recursordist/test-ednsoptions_cc.cc
pdns/recursordist/test-mplexer.cc [new symlink]
pdns/recursordist/test-syncres_cc.cc
pdns/responsestats-auth.cc
pdns/speedtest.cc
pdns/sstuff.hh
pdns/statbag.cc
pdns/statbag.hh
pdns/statnode.hh
pdns/syncres.cc
pdns/syncres.hh
pdns/tcpiohandler.hh
pdns/tcpreceiver.cc
pdns/test-dnsdist_cc.cc
pdns/test-ipcrypt_cc.cc [new file with mode: 0644]
pdns/test-mplexer.cc [new file with mode: 0644]
pdns/test-packetcache_cc.cc
pdns/webserver.cc
pdns/webserver.hh
pdns/ws-auth.cc
pdns/ws-recursor.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/requirements.txt
regression-tests.dnsdist/runtests
regression-tests.dnsdist/test-include-dir/test.conf
regression-tests.dnsdist/test_API.py
regression-tests.dnsdist/test_AXFR.py
regression-tests.dnsdist/test_Advanced.py
regression-tests.dnsdist/test_Basics.py
regression-tests.dnsdist/test_CacheHitResponses.py
regression-tests.dnsdist/test_Caching.py
regression-tests.dnsdist/test_Carbon.py
regression-tests.dnsdist/test_CheckConfig.py
regression-tests.dnsdist/test_DNSCrypt.py
regression-tests.dnsdist/test_Dnstap.py
regression-tests.dnsdist/test_DynBlocks.py
regression-tests.dnsdist/test_EDNSOptions.py
regression-tests.dnsdist/test_EDNSSelfGenerated.py
regression-tests.dnsdist/test_EdnsClientSubnet.py
regression-tests.dnsdist/test_HealthChecks.py
regression-tests.dnsdist/test_Protobuf.py
regression-tests.dnsdist/test_RecordsCount.py
regression-tests.dnsdist/test_Responses.py
regression-tests.dnsdist/test_Routing.py
regression-tests.dnsdist/test_Spoofing.py
regression-tests.dnsdist/test_TCPKeepAlive.py
regression-tests.dnsdist/test_TCPLimits.py
regression-tests.dnsdist/test_TLS.py
regression-tests.dnsdist/test_Tags.py
regression-tests.dnsdist/test_TeeAction.py
regression-tests.dnsdist/test_Trailing.py
regression-tests.recursor-dnssec/test_ECS.py
regression-tests/backends/lmdb-slave

index 3f85d438f6d38e35e5dee6d71407b760dd4b20e6..74a3b840b0ecbbe09385361be9aaa2e268fef93e 100644 (file)
--- a/README.md
+++ b/README.md
@@ -96,11 +96,11 @@ If you run into C++11-related symbol trouble, please try passing `CPPFLAGS=-D_GL
 
 Compiling the Recursor
 ----------------------
-See the README in pdns/recursordist.
+See [README.md](pdns/recursordist/README.md) in `pdns/recursordist/`.
 
 Compiling dnsdist
 -----------------
-See the README in pdns/dnsdistdist.
+See [README-dnsdist.md](pdns/README-dnsdist.md) in `pdns/`.
 
 Building the HTML documentation
 -------------------------------
index b51e0c528501247fc133c83af5fab641e7e89bfa..91a2bd57b721ab4233e7ba89af8d6c33fc4e72eb 100644 (file)
@@ -1,5 +1,5 @@
 FROM dist-base as package-builder
-RUN yum install -y rpm-build rpmdevtools python34 && \
+RUN yum install -y rpm-build rpmdevtools /usr/bin/python3 && \
     yum groupinstall -y "Development Tools" && \
     rpmdev-setuptree
 
index 8ac60b12c51fdcc345e5ac25d43e5d17ad58f99f..1c00fff9756336949b49656bffb9098010cf281d 100644 (file)
@@ -107,7 +107,7 @@ PDNS_CHECK_LIBCRYPTO_EDDSA
 PDNS_CHECK_RAGEL([pdns/dnslabeltext.cc], [www.powerdns.com])
 PDNS_CHECK_CLOCK_GETTIME
 
-BOOST_REQUIRE([1.35])
+BOOST_REQUIRE([1.42])
 # Boost accumulators, as used by dnsbulktest and dnstcpbench, need 1.48+
 # to be compatible with C++11
 AM_CONDITIONAL([HAVE_BOOST_GE_148], [test "$boost_major_version" -ge 148])
@@ -297,6 +297,8 @@ AC_SUBST([AM_CPPFLAGS],
 
 AC_SUBST([YAHTTP_CFLAGS], ['-I$(top_srcdir)/ext/yahttp'])
 AC_SUBST([YAHTTP_LIBS], ['$(top_builddir)/ext/yahttp/yahttp/libyahttp.la'])
+AC_SUBST([IPCRYPT_CFLAGS], ['-I$(top_srcdir)/ext/ipcrypt'])
+AC_SUBST([IPCRYPT_LIBS], ['$(top_builddir)/ext/ipcrypt/libipcrypt.la'])
 
 CXXFLAGS="$SANITIZER_FLAGS $CXXFLAGS"
 
@@ -315,6 +317,7 @@ AC_CONFIG_FILES([
   docs/Makefile
   pdns/pdns.init
   ext/Makefile
+  ext/ipcrypt/Makefile
   ext/yahttp/Makefile
   ext/yahttp/yahttp/Makefile
   ext/json11/Makefile
index c04070d58c30588f8af551215c9d2cdcc82def75..02d39fe999daab71f491c8aafaa3dab87584db54 100644 (file)
@@ -123,6 +123,16 @@ Use the InnoDB READ-COMMITTED transaction isolation level. Default: yes.
 The timeout in seconds for each attempt to read from, or write to the
 server. A value of 0 will disable the timeout. Default: 10
 
+.. _setting-gmysql-thread-cleanup:
+
+``gmysql-thread-cleanup``
+^^^^^^^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.1.8
+
+Older versions (such as those shipped on RHEL 7) of the MySQL/MariaDB client libraries leak memory unless applications explicitly report the end of each thread to the library. Enabling ``gmysql-thread-cleanup`` tells PowerDNS to call ``mysql_thread_end()`` whenever a thread ends.
+
+Only enable this if you are certain you need to. For more discussion, see https://github.com/PowerDNS/pdns/issues/6231.
+
 Default Schema
 --------------
 
index cb396fb7e916e895cd591ba3aed8be385cb767bb..824233413d129067b2bfdd0804f3f5f801df305b 100644 (file)
@@ -79,9 +79,9 @@ The password to for :ref:`setting-gpgsql-user`. Default: not set.
 
 Enable DNSSEC processing for this backend. Default: no.
 
-.. _setting-gpsql-extra-connection-parameters:
+.. _setting-gpgsql-extra-connection-parameters:
 
-``gpsql-extra-connection-parameters``
+``gpgsql-extra-connection-parameters``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Extra connection parameters to forward to postgres. If you want to pin a
index ac1236b6364f105fa0fb2c32dea1ed6ada19e040..0b37c2614c98c5384efa2e7d28fca77e756f1ea4 100644 (file)
@@ -1,6 +1,79 @@
 Changelogs for 4.1.x
 ====================
 
+.. changelog::
+  :version: 4.1.8
+  :released: March 22nd 2019
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7604
+    :tickets: 7494
+
+    Correctly interpret an empty AXFR response to an IXFR query.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7610
+    :tickets: 7341
+
+    Fix replying from ANY address for non-standard port.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7609
+    :tickets: 7580
+
+    Fix rectify for ENT records in narrow zones.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7607
+    :tickets: 7472
+
+    Do not compress the root.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7608
+    :tickets: 7459
+
+    Fix dot stripping in ``setcontent()``.
+
+  .. change::
+    :tags: Bug Fixes, MySQL
+    :pullreq: 7605
+    :tickets: 7496
+
+    Fix invalid SOA record in MySQL which prevented the authoritative server from starting.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7603
+    :tickets: 7294
+
+    Prevent leak of file descriptor if running out of ports for incoming AXFR.
+
+  .. change::
+    :tags: Bug Fixes, API
+    :pullreq: 7602
+    :tickets: 7546
+
+    Fix API search failed with "Commands out of sync; you can't run this command now".
+
+  .. change::
+    :tags: Bug Fixes, MySQL
+    :pullreq: 7509
+    :tickets: 7517
+
+    Plug ``mysql_thread_init`` memory leak.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7567
+
+    EL6: fix ``CXXFLAGS`` to build with compiler optimizations.
+
 .. changelog::
   :version: 4.1.7
   :released: March 18th 2019
index c89b6e55713fbca9a5f5916d8306a0dc57a5ddf4..f6679dafbd3fe8ebe5f10dfb8603b34823f7bc91 100644 (file)
@@ -7,7 +7,7 @@ authoritative server. PowerDNS supports this mode fully.
 
 In addition, PowerDNS supports taking care of the signing itself, in
 which case PowerDNS operates differently from most tutorials and
-handbooks. This mode is easier however.
+handbooks. This mode is easier, however.
 
 For relevant tradeoffs, please see :doc:`../security` and
 :doc:`../performance`.
@@ -18,16 +18,16 @@ Online Signing
 --------------
 
 In the simplest situation, there is a single "SQL" database that
-contains, in separate tables, all domain data, keying material and other
+contains, in separate tables, all domain data, keying material, and other
 DNSSEC related settings.
 
 This database is then replicated to all PowerDNS instances, which all
-serve identical records, keys and signatures.
+serve identical records, keys, and signatures.
 
 In this mode of operation, care should be taken that the database
 replication occurs over a secure network, or over an encrypted
-connection. This is because keying material, if intercepted, could be
-used to counterfeit DNSSEC data using the original keys.
+connection. If intercepted, keying material could be used to counterfeit
+DNSSEC data using the original keys.
 
 Such a single replicated database requires no further attention beyond
 monitoring already required during non-DNSSEC operations.
@@ -45,17 +45,17 @@ Zone Signing Keys (ZSKs). During normal operations, this means that only
 1 ZSK is 'active', and the other is inactive.
 
 Should it be desired to 'roll over' to a new key, both keys can
-temporarily be active (and used for signing), and after a while the old
-key can be inactivated. Subsequently it can be removed.
+temporarily be active (and used for signing), and after a while, the old
+key can be deactivated. Subsequently, it can be removed.
 
-As elucidated above, there are several ways in which DNSSEC can deny the
-existence of a record, and this setting too is stored away from zone
-records, and lives with the DNSSEC keying material.
+As described above, there are several ways in which DNSSEC can deny the
+existence of a record, and this setting, which is also stored away from zone
+records, lives with the DNSSEC keying material.
 
 (Hashed) Denial of Existence
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-PowerDNS supports unhashed secure denial of existence using NSEC
+PowerDNS supports unhashed secure denial-of-existence using NSEC
 records. These are generated with the help of the (database) backend,
 which needs to be able to supply the 'previous' and 'next' records in
 canonical ordering.
@@ -63,7 +63,7 @@ canonical ordering.
 The Generic SQL Backends have fields that allow them to supply these
 relative record names.
 
-In addition, hashed secure denial of existence is supported using NSEC3
+In addition, hashed secure denial-of-existence is supported using NSEC3
 records, in two modes, one with help from the database, the other with
 the help of some additional calculations.
 
@@ -72,7 +72,7 @@ where the backend should be able to supply the previous and next domain
 names in hashed order.
 
 NSEC3 in 'narrow' mode uses additional hashing calculations to provide
-hashed secure denial of existence 'on the fly', without further
+hashed secure denial-of-existence 'on the fly', without further
 involving the database.
 
 .. _dnssec-signatures:
@@ -84,8 +84,8 @@ In PowerDNS live signing mode, signatures, as served through RRSIG
 records, are calculated on the fly, and heavily cached. All CPU cores
 are used for the calculation.
 
-RRSIGs have a validity period, in PowerDNS this period is 3 weeks.
-This period starts at most a week in the past, and continues at least a week into the future.
+RRSIGs have a validity period. In PowerDNS, the RRSIG validity period is 3 weeks.
+This period starts at most a week in the past and continues at least a week into the future.
 This interval jumps with one-week increments every Thursday.
 
 The time period used is always calculated based on the moment of rollover.
@@ -105,7 +105,7 @@ Graphically, it looks like this::
                               |----- RRSIG(1) served -----|----- RRSIG(2) served -----|
 
   |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
-  thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu
+  Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu
                                                           ^
                                                           |
                                                           RRSIG roll-over(1 to 2)
@@ -115,8 +115,8 @@ At all times, only one RRSIG per signed RRset per ZSK is served when responding
 .. note::
   Why Thursday? POSIX-based operating systems count the time
   since GMT midnight January 1st of 1970, which was a Thursday. PowerDNS
-  inception/expiration times are generated based on an integral number of
-  weeks having passed since the start of the 'epoch'.
+  inception/expiration times are generated based on the integral number of
+  weeks since the start of the 'epoch'.
 
 PowerDNS also serves the DNSKEY records in live-signing mode. Their TTL
 is derived from the SOA records *minimum* field. When using NSEC3, the
@@ -127,7 +127,7 @@ Pre-signed records
 
 In this mode, PowerDNS serves zones that already contain DNSSEC records.
 Such zones can either be slaved from a remote master, or can be signed
-using tools like OpenDNSSEC, ldns-signzone or dnssec-signzone.
+using tools like OpenDNSSEC, ldns-signzone, and dnssec-signzone.
 
 Even in this mode, PowerDNS will synthesize NSEC(3) records itself
 because of its architecture. RRSIGs of these NSEC(3) will still need to
@@ -155,7 +155,7 @@ required for the receiving party to rectify the zone without knowing the
 keys, such as signed NSEC3 records for empty non-terminals. The zone is
 not required to be rectified on the master.
 
-Signatures and Hashing is similar as described in :ref:`dnssec-online-signing`.
+The signing and hashing algorithms are described in :ref:`dnssec-online-signing`.
 
 .. _dnssec-modes-bind-mode:
 
@@ -171,7 +171,7 @@ To use this mode, add
 restart PowerDNS.
 
 .. note::
-  This sqlite database is different from the database used for the regular :doc:`SQLite 3 backend <../backends/generic-sqlite3>`.
+  This SQLite database is different from the database used for the regular :doc:`SQLite 3 backend <../backends/generic-sqlite3>`.
 
 After this, you can use ``pdnsutil secure-zone`` and all other pdnsutil
 commands on your BIND zones without trouble.
@@ -182,7 +182,7 @@ Hybrid BIND-mode operation
 --------------------------
 
 PowerDNS can also operate based on 'BIND'-style zone & configuration
-files. This 'bindbackend' has full knowledge of DNSSEC, but has no
+files. This 'bindbackend' has full knowledge of DNSSEC but has no
 native way of storing keying material.
 
 However, since PowerDNS supports operation with multiple simultaneous
index 46f5de42d818ab5cca86a28ad0fb15a241cb2cd4..1b20806226cf48d2b45030e5738f1ed1e379c554 100644 (file)
@@ -25,7 +25,11 @@ about.
 Options
 -------
 
-None
+--decrypt,-d             Undo IPCipher encryption of IP addresses
+--help, -h               Show summary of options.
+--key,-k                 Base64 encoded 128-bit key for IPCipher
+--passphrase,-p          Passphrase that will be used to derive an IPCipher key
+--version,-v             Output version
 
 See also
 --------
index 22b48d80f5707f5904129cd3ca5df00552d2f30d..121a696ef2cd73b2a6af3db2b5a22849a983922a 100644 (file)
@@ -113,6 +113,14 @@ Options
   Entries without a netmask will be interpreted as a single address.
   By default, this list is set to ``127.0.0.0/8`` and ``::1/128``.
 
+:webserver-loglevel:
+  How much the webserver should log: 'none', 'normal' or 'detailed'.
+  When logging, each log-line contains the UUID of the request, this allows finding errors caused by certain requests.
+  With 'none', nothing is logged except for errors.
+  With 'normal' (the default), one line per request is logged in the style of the common log format::
+    [NOTICE] [webserver] 46326eef-b3ba-4455-8e76-15ec73879aa3 127.0.0.1:57566 "GET /metrics HTTP/1.1" 200 1846
+  with 'detailed', the full requests and responses (including headers) are logged along with the regular log-line from 'normal'.
+
 See also
 --------
 
index 2a5511665b98deeeb3c0b5c29d246c27cd481d8a..3b02ccfbc75d3e3011c41bb237477b4bc35cbe84 100644 (file)
@@ -232,6 +232,14 @@ bench-db [*FILE*]
     *FILE* can be a file with a list, one per line, of domain names to use for this.
     If *FILE* is not specified, powerdns.com is used.
 
+OTHER TOOLS
+-----------
+ipencrypt *IP-ADDRESS* passsword
+    Encrypt an IP address according to the 'ipcipher' standard
+
+ipdecrypt *IP-ADDRESS* passsword
+    Encrypt an IP address according to the 'ipcipher' standard
+
 See also
 --------
 
index aa5aed5e6ab3b01c656c0b64c6898813ff4b2300..21c25746b9623d971bfb9ac5e44616ae2d37d3a9 100644 (file)
@@ -1,4 +1,4 @@
-@       86400   IN  SOA pdns-public-ns1.powerdns.com. pieter\.lexis.powerdns.com. 2019031901 10800 3600 604800 10800
+@       86400   IN  SOA pdns-public-ns1.powerdns.com. pieter\.lexis.powerdns.com. 2019041201 10800 3600 604800 10800
 @       3600    IN  NS  pdns-public-ns1.powerdns.com.
 @       3600    IN  NS  pdns-public-ns2.powerdns.com.
 ; Auth
@@ -44,6 +44,7 @@ auth-4.1.4.security-status                              60 IN TXT "3 Upgrade now
 auth-4.1.5.security-status                              60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.1.6.security-status                              60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.1.7.security-status                              60 IN TXT "1 OK"
+auth-4.1.8.security-status                              60 IN TXT "1 OK"
 auth-4.2.0-alpha1.security-status                       60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.2.0-beta1.security-status                        60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.2.0-rc1.security-status                          60 IN TXT "1 OK"
@@ -179,6 +180,7 @@ recursor-4.1.8.security-status                          60 IN TXT "3 Upgrade now
 recursor-4.1.9.security-status                          60 IN TXT "1 OK"
 recursor-4.1.10.security-status                         60 IN TXT "1 OK"
 recursor-4.1.11.security-status                         60 IN TXT "1 OK"
+recursor-4.1.12.security-status                         60 IN TXT "1 OK"
 recursor-4.2.0-alpha1.security-status                   60 IN TXT "1 OK"
 
 ; Recursor Debian
@@ -298,3 +300,4 @@ recursor-4.0.0_beta1-1pdns.jessie.raspbian.security-status 60 IN TXT "3 Upgrade
 
 ; dnsdist
 dnsdist-1.3.3.security-status                              60 IN TXT "1 OK"
+dnsdist-1.4.0-alpha1.security-status                       60 IN TXT "1 OK"
index a622d765825064af3ffbe772bc6ee2bb8c7a3dc7..eab1244d270046cbe771b179dcf3c2b2f772292c 100644 (file)
@@ -1615,6 +1615,47 @@ IP Address for webserver/API to listen on.
 
 Webserver/API access is only allowed from these subnets.
 
+.. _setting-webserver-loglevel:
+
+``webserver-loglevel``
+----------------------
+.. versionadded:: 4.2.0
+
+-  String, one of "none", "normal", "detailed"
+
+The amount of logging the webserver must do. "none" means no useful webserver information will be logged.
+When set to "normal", the webserver will log a line per request that should be familiar::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+When set to "detailed", all information about the request and response are logged::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Request Details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-encoding: gzip, deflate
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-language: en-US,en;q=0.5
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   connection: keep-alive
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   dnt: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   host: 127.0.0.1:8081
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   upgrade-insecure-requests: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   user-agent: Mozilla/5.0 (X11; Linux x86_64; rv:64.0) Gecko/20100101 Firefox/64.0
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  No body
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Response details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Connection: close
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Length: 49
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Type: text/html; charset=utf-8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Server: PowerDNS/0.0.15896.0.gaba8bab3ab
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Full body: 
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   <!html><title>Not Found</title><h1>Not Found</h1>
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+The value between the hooks is a UUID that is generated for each request. This can be used to find all lines related to a single request.
+
+.. note::
+  The webserver logs these line on the NOTICE level. The :ref:`settings-loglevel` seting must be 5 or higher for these lines to end up in the log.
+
 .. _setting-webserver-password:
 
 ``webserver-password``
index 49dc5a219216ad3860bb46a49888ee0f280db098..7c0a42d4199fc6feaaaf3168da99b1e499b1e689 100644 (file)
@@ -1,10 +1,12 @@
 SUBDIRS = \
-       yahttp \
-        json11
+       ipcrypt \
+       json11 \
+       yahttp
 
 DIST_SUBDIRS = \
-       yahttp \
-        json11
+       ipcrypt \
+       json11 \
+       yahttp
 
 EXTRA_DIST = \
        luawrapper/include/LuaContext.hpp
diff --git a/ext/ipcrypt/.gitignore b/ext/ipcrypt/.gitignore
new file mode 100644 (file)
index 0000000..24ad051
--- /dev/null
@@ -0,0 +1,5 @@
+*.la
+*.lo
+*.o
+Makefile
+Makefile.in
diff --git a/ext/ipcrypt/LICENSE b/ext/ipcrypt/LICENSE
new file mode 100644 (file)
index 0000000..0a199e5
--- /dev/null
@@ -0,0 +1,14 @@
+Copyright (c) 2015-2018, Frank Denis <j at pureftpd dot org>
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
diff --git a/ext/ipcrypt/Makefile.am b/ext/ipcrypt/Makefile.am
new file mode 100644 (file)
index 0000000..a68b2d4
--- /dev/null
@@ -0,0 +1,5 @@
+noinst_LTLIBRARIES = libipcrypt.la
+
+libipcrypt_la_SOURCES = \
+       ipcrypt.c \
+       ipcrypt.h
diff --git a/ext/ipcrypt/ipcrypt.c b/ext/ipcrypt/ipcrypt.c
new file mode 100644 (file)
index 0000000..6ef464a
--- /dev/null
@@ -0,0 +1,87 @@
+
+#include "ipcrypt.h"
+
+#define ROTL(X, R) (X) = (unsigned char) ((X) << (R)) | ((X) >> (8 - (R)))
+
+static void
+arx_fwd(unsigned char state[4])
+{
+    state[0] += state[1];
+    state[2] += state[3];
+    ROTL(state[1], 2);
+    ROTL(state[3], 5);
+    state[1] ^= state[0];
+    state[3] ^= state[2];
+    ROTL(state[0], 4);
+    state[0] += state[3];
+    state[2] += state[1];
+    ROTL(state[1], 3);
+    ROTL(state[3], 7);
+    state[1] ^= state[2];
+    state[3] ^= state[0];
+    ROTL(state[2], 4);
+}
+
+static void
+arx_bwd(unsigned char state[4])
+{
+    ROTL(state[2], 4);
+    state[1] ^= state[2];
+    state[3] ^= state[0];
+    ROTL(state[1], 5);
+    ROTL(state[3], 1);
+    state[0] -= state[3];
+    state[2] -= state[1];
+    ROTL(state[0], 4);
+    state[1] ^= state[0];
+    state[3] ^= state[2];
+    ROTL(state[1], 6);
+    ROTL(state[3], 3);
+    state[0] -= state[1];
+    state[2] -= state[3];
+}
+
+static inline void
+xor4(unsigned char *out, const unsigned char *x, const unsigned char *y)
+{
+    out[0] = x[0] ^ y[0];
+    out[1] = x[1] ^ y[1];
+    out[2] = x[2] ^ y[2];
+    out[3] = x[3] ^ y[3];
+}
+
+int
+ipcrypt_encrypt(unsigned char out[IPCRYPT_BYTES],
+                const unsigned char in[IPCRYPT_BYTES],
+                const unsigned char key[IPCRYPT_KEYBYTES])
+{
+    unsigned char state[4];
+
+    xor4(state, in, key);
+    arx_fwd(state);
+    xor4(state, state, key + 4);
+    arx_fwd(state);
+    xor4(state, state, key + 8);
+    arx_fwd(state);
+    xor4(out, state, key + 12);
+
+    return 0;
+}
+
+int
+ipcrypt_decrypt(unsigned char out[IPCRYPT_BYTES],
+                const unsigned char in[IPCRYPT_BYTES],
+                const unsigned char key[IPCRYPT_KEYBYTES])
+{
+    unsigned char state[4];
+
+    xor4(state, in, key + 12);
+    arx_bwd(state);
+    xor4(state, state, key + 8);
+    arx_bwd(state);
+    xor4(state, state, key + 4);
+    arx_bwd(state);
+    xor4(out, state, key);
+
+    return 0;
+}
diff --git a/ext/ipcrypt/ipcrypt.h b/ext/ipcrypt/ipcrypt.h
new file mode 100644 (file)
index 0000000..76b94f5
--- /dev/null
@@ -0,0 +1,24 @@
+
+#ifndef ipcrypt_H
+#define ipcrypt_H
+
+#define IPCRYPT_BYTES 4
+#define IPCRYPT_KEYBYTES 16
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ipcrypt_encrypt(unsigned char out[IPCRYPT_BYTES],
+                    const unsigned char in[IPCRYPT_BYTES],
+                    const unsigned char key[IPCRYPT_KEYBYTES]);
+
+int ipcrypt_decrypt(unsigned char out[IPCRYPT_BYTES],
+                    const unsigned char in[IPCRYPT_BYTES],
+                    const unsigned char key[IPCRYPT_KEYBYTES]);
+
+#ifdef __cplusplus
+}  /* End of the 'extern "C"' block */
+#endif
+
+#endif
index b0e6a39e4cf1b2f6eb5b4223f23e3d790429d099..c71c98acc8208cb22fee669e51cc582f2c731a0d 100644 (file)
@@ -90,6 +90,10 @@ AC_DEFUN([PDNS_CHECK_LIBCRYPTO], [
         # it will just work!
     fi
 
+    if $found; then
+        AC_DEFINE([HAVE_LIBCRYPTO], [1], [Define to 1 if you have OpenSSL libcrypto])
+    fi
+
     # try the preprocessor and linker with our new flags,
     # being careful not to pollute the global LIBS, LDFLAGS, and CPPFLAGS
 
@@ -120,4 +124,5 @@ AC_DEFUN([PDNS_CHECK_LIBCRYPTO], [
     AC_SUBST([LIBCRYPTO_INCLUDES])
     AC_SUBST([LIBCRYPTO_LIBS])
     AC_SUBST([LIBCRYPTO_LDFLAGS])
+    AM_CONDITIONAL([HAVE_LIBCRYPTO], [test "x$LIBCRYPTO_LIBS" != "x"])
 ])
index e3033abe3b1a7a2a68b5ca93b01b17cf4b45451c..48707d3a95f250ea6780d2c21484fb2a05a02968 100644 (file)
@@ -59,7 +59,8 @@ void gMySQLBackend::reconnect()
                    getArg("password"),
                    getArg("group"),
                    mustDo("innodb-read-committed"),
-                   getArgAsNum("timeout")));
+                   getArgAsNum("timeout"),
+                   mustDo("thread-cleanup")));
 }
 
 class gMySQLFactory : public BackendFactory
@@ -78,6 +79,7 @@ public:
     declare(suffix,"group", "Database backend MySQL 'group' to connect as", "client");
     declare(suffix,"innodb-read-committed","Use InnoDB READ-COMMITTED transaction isolation level","yes");
     declare(suffix,"timeout", "The timeout in seconds for each attempt to read/write to the server", "10");
+    declare(suffix,"thread-cleanup","Explicitly call mysql_thread_end() when threads end","no");
 
     declare(suffix,"dnssec","Enable DNSSEC processing","no");
 
index 1d851daa91e3094a8fc637a35f6ee1220963a2f0..0b062574776f7bfe9ccc0232c9849d5113df8364 100644 (file)
 typedef bool my_bool;
 #endif
 
+/*
+ * Older versions of the MySQL and MariaDB client leak memory
+ * because they expect the application to call mysql_thread_end()
+ * when a thread ends. This thread_local static object provides
+ * that closure, but only when the user has asked for it
+ * by setting gmysql-thread-cleanup.
+ * For more discussion, see https://github.com/PowerDNS/pdns/issues/6231
+ */
+class MySQLThreadCloser
+{
+public:
+  ~MySQLThreadCloser() {
+    if(d_enabled) {
+      mysql_thread_end();
+    }
+  }
+  void enable() {
+   d_enabled = true;
+  }
+
+private:
+  bool d_enabled = false;
+};
+
+static thread_local MySQLThreadCloser threadcloser;
+
 bool SMySQL::s_dolog;
 pthread_mutex_t SMySQL::s_myinitlock = PTHREAD_MUTEX_INITIALIZER;
 
@@ -419,6 +445,10 @@ void SMySQL::connect()
   int retry=1;
 
   Lock l(&s_myinitlock);
+  if (d_threadCleanup) {
+    threadcloser.enable();
+  }
+
   if (!mysql_init(&d_db))
     throw sPerrorException("Unable to initialize mysql driver");
 
@@ -467,8 +497,8 @@ void SMySQL::connect()
 }
 
 SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const string &msocket, const string &user,
-               const string &password, const string &group, bool setIsolation, unsigned int timeout):
-  d_database(database), d_host(host), d_msocket(msocket), d_user(user), d_password(password), d_group(group), d_timeout(timeout), d_port(port), d_setIsolation(setIsolation)
+               const string &password, const string &group, bool setIsolation, unsigned int timeout, bool threadCleanup):
+  d_database(database), d_host(host), d_msocket(msocket), d_user(user), d_password(password), d_group(group), d_timeout(timeout), d_port(port), d_setIsolation(setIsolation), d_threadCleanup(threadCleanup)
 {
   connect();
 }
index 7e31d7b5644f1da8d405318253abe22e08f6d305..7a33e8c529d12357430a4cf3418320a6a3860eca 100644 (file)
@@ -32,7 +32,8 @@ public:
   SMySQL(const string &database, const string &host="", uint16_t port=0,
          const string &msocket="",const string &user="",
          const string &password="", const string &group="",
-         bool setIsolation=false, unsigned int timeout=10);
+         bool setIsolation=false, unsigned int timeout=10,
+         bool threadCleanup=false);
 
   ~SMySQL();
 
@@ -61,6 +62,7 @@ private:
   unsigned int d_timeout;
   uint16_t d_port;
   bool d_setIsolation;
+  bool d_threadCleanup;
 };
 
 #endif /* SSMYSQL_HH */
index eea586455bb0903eb413f99d5fc2e335120ced35..094303f2533dbfab77978f059bbab4c66d7922bf 100644 (file)
@@ -26,8 +26,10 @@ MDBDbi::MDBDbi(MDB_env* env, MDB_txn* txn, const string_view dbname, int flags)
 
 MDBEnv::MDBEnv(const char* fname, int flags, int mode)
 {
-  mdb_env_create(&d_env);   
-  if(mdb_env_set_mapsize(d_env, 16ULL*4096*244140ULL)) // 4GB
+  mdb_env_create(&d_env);
+  uint64_t mapsizeMB = (sizeof(long)==4) ? 100 : 16000;
+  // on 32 bit platforms, there is just no room for more
+  if(mdb_env_set_mapsize(d_env, mapsizeMB * 1048576))
     throw std::runtime_error("setting map size");
     /*
 Various other options may also need to be set before opening the handle, e.g. mdb_env_set_mapsize(), mdb_env_set_maxreaders(), mdb_env_set_maxdbs(),
index de04d320857bbd343fe76e1c5ab350d2b2002c0c..de5e9aef850312848d10656c58432caff2a2d4b0 100644 (file)
@@ -13,7 +13,7 @@
 // apple compiler somehow has string_view even in c++11!
 #if __cplusplus < 201703L && !defined(__APPLE__)
 #include <boost/version.hpp>
-#if BOOST_VERSION > 105400
+#if BOOST_VERSION >= 106100
 #include <boost/utility/string_view.hpp>
 using boost::string_view;
 #else
index a64dd1110af1f9d4c76c8d0bf04b64de0673b6bd..e6da3dd7d21d9397c645545bcf270d4ebc887f81 100644 (file)
@@ -256,7 +256,7 @@ public:
       return count;
     }
 
-    //! End iderator type
+    //! End iterator type
     struct eiter_t
     {};
 
index 8b9ca2bc67b1cc25f4259d9b1ef85323f071ab23..0ddc65fe8087c137d815a2241ae7d23d28027bde 100644 (file)
@@ -224,10 +224,10 @@ std::shared_ptr<DNSRecordContent> unserializeContentZR(uint16_t qtype, const DNS
    Note - domain_id, name and type are ONLY present on the index!
 */
 
-#if BOOST_VERSION <= 105400
-#define StringView string
-#else
+#if BOOST_VERSION >= 106100
 #define StringView string_view
+#else
+#define StringView string
 #endif
 
 void LMDBBackend::deleteDomainRecords(RecordsRWTransaction& txn, uint32_t domain_id, uint16_t qtype)
@@ -808,7 +808,7 @@ bool LMDBBackend::setMaster(const DNSName &domain, const std::string& ips)
   vector<string> parts;
   stringtok(parts, ips, " \t;,");
   for(const auto& ip : parts) 
-    masters.push_back(ComboAddress(ip)); 
+    masters.push_back(ComboAddress(ip, 53));
   
   return genChangeDomain(domain, [&masters](DomainInfo& di) {
       di.masters = masters;
@@ -858,8 +858,17 @@ void LMDBBackend::getAllDomains(vector<DomainInfo> *domains, bool include_disabl
 
     auto txn = getRecordsROTransaction(iter.getID());
     if(!txn->txn.get(txn->db->dbi, co(di.id, g_rootdnsname, QType::SOA), val)) {
-      domains->push_back(di);
+      DNSResourceRecord rr;
+      serFromString(val.get<string_view>(), rr);
+
+      if(rr.content.size() >= 5 * sizeof(uint32_t)) {
+        uint32_t serial = *reinterpret_cast<uint32_t*>(&rr.content[rr.content.size() - (5 * sizeof(uint32_t))]);
+        di.serial = ntohl(serial);
+      }
+    } else if(!include_disabled) {
+      continue;
     }
+    domains->push_back(di);
   }
 }
 
@@ -1583,7 +1592,8 @@ public:
   {
     declare(suffix,"filename","Filename for lmdb","./pdns.lmdb");
     declare(suffix,"sync-mode","Synchronisation mode: nosync, nometasync, mapasync","mapasync");
-    declare(suffix,"shards","Records database will be split into this number of shards","64");
+    // there just is no room for more on 32 bit
+    declare(suffix,"shards","Records database will be split into this number of shards", (sizeof(long) == 4) ? "2" : "64"); 
   }
   DNSBackend *make(const string &suffix="")
   {
index f0acec88a65dbee2ce2c997a8a40ca33a246772f..6749e4cf6e15e6d3aa4608ff818209b56a531ea4 100644 (file)
@@ -223,6 +223,7 @@ pdns_server_SOURCES = \
        tsigutils.hh tsigutils.cc \
        tkey.cc \
        ueberbackend.cc ueberbackend.hh \
+       uuid-utils.hh uuid-utils.cc \
        unix_semaphore.cc \
        unix_utility.cc \
        utility.hh \
@@ -312,6 +313,7 @@ pdnsutil_SOURCES = \
        ednsoptions.cc ednsoptions.hh \
        ednssubnet.cc \
        gss_context.cc gss_context.hh \
+       ipcipher.cc ipcipher.hh \
        iputils.cc iputils.hh \
        json.cc \
        logger.cc \
@@ -349,7 +351,8 @@ pdnsutil_LDADD = \
        $(JSON11_LIBS) \
        $(LIBDL) \
        $(BOOST_PROGRAM_OPTIONS_LIBS) \
-       $(LIBCRYPTO_LIBS)
+       $(LIBCRYPTO_LIBS) \
+       $(IPCRYPT_LIBS)
 
 if LIBSODIUM
 pdnsutil_SOURCES += sodiumsigners.cc
@@ -640,6 +643,7 @@ ixfrdist_SOURCES = \
        threadname.hh threadname.cc \
        tsigverifier.cc tsigverifier.hh \
        unix_utility.cc \
+       uuid-utils.hh uuid-utils.cc \
        webserver.hh webserver.cc \
        zoneparser-tng.cc
 
@@ -880,18 +884,22 @@ speedtest_LDADD = $(LIBCRYPTO_LIBS) \
        $(RT_LIBS)
 
 dnswasher_SOURCES = \
+       base64.cc \
        dnslabeltext.cc \
        dnsname.hh dnsname.cc \
        dnsparser.hh \
        dnspcap.cc dnspcap.hh \
        dnswasher.cc \
        dnswriter.hh \
+       ipcipher.cc ipcipher.hh \
        logger.cc \
        misc.cc \
        qtype.cc \
        statbag.cc \
        unix_utility.cc
 
+dnswasher_LDFLAGS =    $(AM_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LDFLAGS) $(LIBCRYPTO_LDFLAGS)
+dnswasher_LDADD =      $(BOOST_PROGRAM_OPTIONS_LIBS) $(LIBCRYPTO_LIBS) $(IPCRYPT_LIBS)
 
 dnsbulktest_SOURCES = \
        base32.cc \
@@ -1272,6 +1280,7 @@ testrunner_SOURCES = \
        ednssubnet.cc \
        gettime.cc gettime.hh \
        gss_context.cc gss_context.hh \
+       ipcipher.cc ipcipher.hh \
        iputils.cc \
        ixfr.cc ixfr.hh \
        logger.cc \
@@ -1282,6 +1291,7 @@ testrunner_SOURCES = \
        nameserver.cc \
        nsecrecords.cc \
        opensslsigners.cc opensslsigners.hh \
+       pollmplexer.cc \
        qtype.cc \
        rcpgenerator.cc \
        responsestats.cc \
@@ -1301,11 +1311,13 @@ testrunner_SOURCES = \
        test-dnsparser_cc.cc \
        test-dnsparser_hh.cc \
        test-dnsrecords_cc.cc \
+       test-ipcrypt_cc.cc \
        test-iputils_hh.cc \
        test-ixfr_cc.cc \
        test-lock_hh.cc \
        test-lua_auth4_cc.cc \
        test-misc_hh.cc \
+       test-mplexer.cc \
        test-nameserver_cc.cc \
        test-packetcache_cc.cc \
        test-packetcache_hh.cc \
@@ -1332,7 +1344,8 @@ testrunner_LDADD = \
        $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) \
        $(RT_LIBS) \
        $(LUA_LIBS) \
-       $(LIBDL)
+       $(LIBDL) \
+       $(IPCRYPT_LIBS)
 
 if PKCS11
 testrunner_SOURCES += pkcs11signers.cc pkcs11signers.hh
@@ -1349,6 +1362,25 @@ testrunner_SOURCES += decafsigners.cc
 testrunner_LDADD += $(LIBDECAF_LIBS)
 endif
 
+if HAVE_FREEBSD
+ixfrdist_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
+endif
+
+if HAVE_LINUX
+ixfrdist_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
+endif
+
+if HAVE_SOLARIS
+ixfrdist_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
+testrunner_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
+endif
+
 pdns_control_SOURCES = \
        arguments.cc \
        dynloader.cc \
@@ -1436,6 +1468,7 @@ fuzz_target_packetcache_SOURCES = \
        ednsoptions.cc ednsoptions.hh \
        misc.cc misc.hh \
        packetcache.hh \
+       qtype.cc qtype.hh \
        statbag.cc statbag.hh
 
 fuzz_target_packetcache_DEPENDENCIES = $(fuzz_targets_deps)
index 670f460fa7a70faddee014e0a49697a69fda1e7f..f8e67cb0772b0054693347ed7def0679c53630f3 100644 (file)
@@ -21,9 +21,9 @@ Install dependencies from Homebrew:
 brew install autoconf automake boost libedit libsodium libtool lua pkg-config protobuf
 ```
 
-Let configure know where to find libedit:
+Let configure know where to find libedit, and openssl or libressl:
 
 ```sh
-./configure 'PKG_CONFIG_PATH=/usr/local/opt/libedit/lib/pkgconfig'
+./configure 'PKG_CONFIG_PATH=/usr/local/opt/libedit/lib/pkgconfig:/usr/local/opt/libressl/lib/pkgconfig'
 make
 ```
index 95f850c8116dc2b0d9c164a33fd7950558859876..64f8c758e550e7d7411f32c10f2049e638f92581 100644 (file)
@@ -930,7 +930,7 @@ bool GSQLBackend::getAllDomainMetadata(const DNSName& name, std::map<std::string
       d_GetAllDomainMetadataQuery_stmt->nextRow(row);
       ASSERT_ROW_COLUMNS("get-all-domain-metadata-query", row, 2);
 
-      if (!isDnssecDomainMetadata(row[0]))
+      if (d_dnssecQueries || !isDnssecDomainMetadata(row[0]))
         meta[row[0]].push_back(row[1]);
     }
 
index ced7c304f5d4d598477bb53952e3a248c0ecc678..e7f6f5ab95255b3ebe17a214c2420ae131aa0634 100644 (file)
@@ -150,6 +150,7 @@ void declareArguments()
   ::arg().set("webserver-port","Port of webserver/API to listen on")="8081";
   ::arg().set("webserver-password","Password required for accessing the webserver")="";
   ::arg().set("webserver-allow-from","Webserver/API access is only allowed from these subnets")="127.0.0.1,::1";
+  ::arg().set("webserver-loglevel", "Amount of logging in the webserver (none, normal, detailed)") = "normal";
 
   ::arg().setSwitch("do-ipv6-additional-processing", "Do AAAA additional processing")="yes";
   ::arg().setSwitch("query-logging","Hint backends that queries should be logged")="no";
@@ -336,11 +337,11 @@ void declareStats(void)
   S.declare("latency","Average number of microseconds needed to answer a question", getLatency);
   S.declare("timedout-packets","Number of packets which weren't answered within timeout set");
   S.declare("security-status", "Security status based on regular polling");
-  S.declareRing("queries","UDP Queries Received");
-  S.declareRing("nxdomain-queries","Queries for non-existent records within existent domains");
-  S.declareRing("noerror-queries","Queries for existing records, but for type we don't have");
-  S.declareRing("servfail-queries","Queries that could not be answered due to backend errors");
-  S.declareRing("unauth-queries","Queries for domains that we are not authoritative for");
+  S.declareDNSNameQTypeRing("queries","UDP Queries Received");
+  S.declareDNSNameQTypeRing("nxdomain-queries","Queries for non-existent records within existent domains");
+  S.declareDNSNameQTypeRing("noerror-queries","Queries for existing records, but for type we don't have");
+  S.declareDNSNameQTypeRing("servfail-queries","Queries that could not be answered due to backend errors");
+  S.declareDNSNameQTypeRing("unauth-queries","Queries for domains that we are not authoritative for");
   S.declareRing("logmessages","Log Messages");
   S.declareComboRing("remotes","Remote server IP addresses");
   S.declareComboRing("remotes-unauth","Remote hosts querying domains for which we are not auth");
@@ -422,7 +423,7 @@ try
      if(P->d.qr)
        continue;
 
-    S.ringAccount("queries", P->qdomain.toLogString()+"/"+P->qtype.getName());
+    S.ringAccount("queries", P->qdomain, P->qtype);
     S.ringAccount("remotes",P->d_remote);
     if(logDNSQueries) {
       string remote;
index f9234965a8f9c56a9dc218ecd847cc70a1c6b3ff..35df6dc8fc9940f5e498d409a4cab5659da1ce6d 100644 (file)
@@ -49,7 +49,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -82,9 +82,9 @@ DevPollFDMultiplexer::DevPollFDMultiplexer()
     
 }
 
-void DevPollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+void DevPollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct pollfd devent;
   devent.fd=fd;
@@ -160,13 +160,13 @@ int DevPollFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(dvp.dp_fds[n].fd);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't refind ourselves as writable!
     }
     d_iter=d_writeCallbacks.find(dvp.dp_fds[n].fd);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
   delete[] dvp.dp_fds;
index aa53f64a11050ef3458b5c4f41dff856fb8f1c45..ecc1509cd49525ca92f663e1c764e17845039642 100644 (file)
@@ -322,13 +322,12 @@ template<class Answer, class Question, class Backend>int MultiThreadDistributor<
   QD->Q=q;
   auto ret = QD->id = nextid++; // might be deleted after write!
   QD->callback=callback;
-  
-  if(write(d_pipes[QD->id % d_pipes.size()].second, &QD, sizeof(QD)) != sizeof(QD))
-    unixDie("write");
-
-  d_queued++;
-
 
+  ++d_queued;
+  if(write(d_pipes[QD->id % d_pipes.size()].second, &QD, sizeof(QD)) != sizeof(QD)) {
+    --d_queued;
+    unixDie("write");
+  }
 
   if(d_queued > d_maxQueueLength) {
     g_log<<Logger::Error<< d_queued <<" questions waiting for database/backend attention. Limit is "<<::arg().asNum("max-queue-length")<<", respawning"<<endl;
index f81263c09fcbbb2d5445a5f582eec1b566b98aa7..4cf126ce389d6afc7f8bb979670e34c461e458e9 100644 (file)
@@ -283,13 +283,19 @@ std::string DNSCryptContext::certificateDateToStr(uint32_t date)
   return string(buf);
 }
 
-void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active)
+void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active, bool reload)
 {
   WriteLock w(&d_lock);
 
   for (auto pair : certs) {
     if (pair->cert.getSerial() == newCert.getSerial()) {
-      throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+      if (reload) {
+        /* on reload we just assume that this is the same certificate */
+        return;
+      }
+      else {
+        throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+      }
     }
   }
 
@@ -301,7 +307,7 @@ void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCr
   certs.push_back(pair);
 }
 
-void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active)
+void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active, bool reload)
 {
   DNSCryptCert newCert;
   DNSCryptPrivateKey newPrivateKey;
@@ -309,7 +315,14 @@ void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std:
   loadCertFromFile(certFile, newCert);
   newPrivateKey.loadFromFile(keyFile);
 
-  addNewCertificate(newCert, newPrivateKey, active);
+  addNewCertificate(newCert, newPrivateKey, active, reload);
+  certificatePath = certFile;
+  keyPath = keyFile;
+}
+
+void DNSCryptContext::reloadCertificate()
+{
+  loadNewCertificate(certificatePath, keyPath, true, true);
 }
 
 void DNSCryptContext::markActive(uint32_t serial)
index 40876017cd5527f2b30227c0ce09d420f0332fb6..86ddcd20159fdad28b485fd3b97bafa7dab323bc 100644 (file)
 
 #ifndef HAVE_DNSCRYPT
 
+/* let's just define a few types and values so that the rest of
+   the code can ignore whether DNSCrypt support is available */
+#define DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE (0)
+
+class DNSCryptContext
+{
+};
+
 class DNSCryptQuery
 {
+  DNSCryptQuery(const std::shared_ptr<DNSCryptContext>& ctx): d_ctx(ctx)
+  {
+  }
+private:
+  std::shared_ptr<DNSCryptContext> d_ctx{nullptr};
 };
 
-#else
+#else /* HAVE_DNSCRYPT */
 
 #include <memory>
 #include <string>
@@ -244,8 +257,9 @@ public:
   DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile);
   DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey);
 
-  void loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active=true);
-  void addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active=true);
+  void reloadCertificate();
+  void loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active=true, bool reload=false);
+  void addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active=true, bool reload=false);
   void markActive(uint32_t serial);
   void markInactive(uint32_t serial);
   void removeInactiveCertificate(uint32_t serial);
@@ -263,6 +277,8 @@ private:
   pthread_rwlock_t d_lock;
   std::vector<std::shared_ptr<DNSCryptCertificatePair>> certs;
   DNSName providerName;
+  std::string certificatePath;
+  std::string keyPath;
 };
 
 bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DNSCryptExchangeVersion version, DNSCryptCert& certOut, DNSCryptPrivateKey& keyOut);
index a359ea8a2a0a26ba4158bc7630c97d6acdcf0169..2736dc0480ce441d59bf0d1e3facc3b681a2a15d 100644 (file)
@@ -233,7 +233,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t
     }
 
     const CacheValue& value = it->second;
-    if (value.validity < now) {
+    if (value.validity <= now) {
       if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
         d_misses++;
         return false;
@@ -314,7 +314,7 @@ size_t DNSDistPacketCache::purgeExpired(size_t upTo)
     for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
       const CacheValue& value = it->second;
 
-      if (value.validity < now) {
+      if (value.validity <= now) {
         it = map.erase(it);
         --toRemove;
         d_shards[shardIndex].d_entriesCount--;
index f8f4d943d881175c4a20446891cad90d21c46000..1fa49f58468010255209bff4ab789ac568a073d8 100644 (file)
@@ -95,6 +95,14 @@ try
           str<<base<<"latency" << ' ' << (state->availability != DownstreamState::Availability::Down ? state->latencyUsec/1000.0 : 0) << " " << now << "\r\n";
           str<<base<<"senderrors" << ' ' << state->sendErrors.load() << " " << now << "\r\n";
           str<<base<<"outstanding" << ' ' << state->outstanding.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedsendingquery" << ' '<< state->tcpDiedSendingQuery.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedreaddingresponse" << ' '<< state->tcpDiedReadingResponse.load() << " " << now << "\r\n";
+          str<<base<<"tcpgaveup" << ' '<< state->tcpGaveUp.load() << " " << now << "\r\n";
+          str<<base<<"tcpreadimeouts" << ' '<< state->tcpReadTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpwritetimeouts" << ' '<< state->tcpWriteTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpcurrentconnections" << ' '<< state->tcpCurrentConnections.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgqueriesperconnection" << ' '<< state->tcpAvgQueriesPerConnection.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgconnectionduration" << ' '<< state->tcpAvgConnectionDuration.load() << " " << now << "\r\n";
         }
         for(const auto& front : g_frontends) {
           if (front->udpFD == -1 && front->tcpFD == -1)
@@ -104,6 +112,14 @@ try
           boost::replace_all(frontName, ".", "_");
           const string base = namespace_name + "." + hostname + "." + instance_name + ".frontends." + frontName + ".";
           str<<base<<"queries" << ' ' << front->queries.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedreadingquery" << ' '<< front->tcpDiedReadingQuery.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedsendingresponse" << ' '<< front->tcpDiedSendingResponse.load() << " " << now << "\r\n";
+          str<<base<<"tcpgaveup" << ' '<< front->tcpGaveUp.load() << " " << now << "\r\n";
+          str<<base<<"tcpclientimeouts" << ' '<< front->tcpClientTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpdownstreamtimeouts" << ' '<< front->tcpDownstreamTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpcurrentconnections" << ' '<< front->tcpCurrentConnections.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgqueriesperconnection" << ' '<< front->tcpAvgQueriesPerConnection.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgconnectionduration" << ' '<< front->tcpAvgConnectionDuration.load() << " " << now << "\r\n";
         }
         auto localPools = g_pools.getLocal();
         for (const auto& entry : *localPools) {
index 87e6b10ef985ff6e19e957cb7473d35652cd732a..8fde18cb5d0b58c193f451bd7d9adbf9b040c99e 100644 (file)
@@ -347,8 +347,6 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "addDNSCryptBind", true, "\"127.0.0.1:8443\", \"provider name\", \"/path/to/resolver.cert\", \"/path/to/resolver.key\", {reusePort=false, tcpFastOpenSize=0, interface=\"\", cpus={}}", "listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of `provider name`, using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files. The fifth optional parameter is a table of parameters" },
   { "addDynBlocks", true, "addresses, message[, seconds[, action]]", "block the set of addresses with message `msg`, for `seconds` seconds (10 by default), applying `action` (default to the one set with `setDynBlocksAction()`)" },
   { "addLocal", true, "addr [, {doTCP=true, reusePort=false, tcpFastOpenSize=0, interface=\"\", cpus={}}]", "add `addr` to the list of addresses we listen on" },
-  { "addLuaAction", true, "x, func [, {uuid=\"UUID\"}]", "where 'x' is all the combinations from `addAction`, and func is a function with the parameter `dq`, which returns an action to be taken on this packet. Good for rare packets but where you want to do a lot of processing" },
-  { "addLuaResponseAction", true, "x, func [, {uuid=\"UUID\"}]", "where 'x' is all the combinations from `addAction`, and func is a function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing" },
   { "addCacheHitResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a cache hit response rule" },
   { "addResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a response rule" },
   { "addSelfAnsweredResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a self-answered response rule" },
@@ -430,6 +428,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "RCodeRule", true, "rcode", "matches responses with the specified rcode" },
   { "RegexRule", true, "regex", "matches the query name against the supplied regex" },
   { "registerDynBPFFilter", true, "DynBPFFilter", "register this dynamic BPF filter into the web interface so that its counters are displayed" },
+  { "reloadAllCertificates", true, "", "reload all DNSCrypt and TLS certificates, along with their associated keys" },
   { "RemoteLogAction", true, "RemoteLogger [, alterFunction [, serverID]]", "send the content of this query to a remote logger via Protocol Buffer. `alterFunction` is a callback, receiving a DNSQuestion and a DNSDistProtoBufMessage, that can be used to modify the Protocol Buffer content, for example for anonymization purposes. `serverID` is the server identifier." },
   { "RemoteLogResponseAction", true, "RemoteLogger [,alterFunction [,includeCNAME [, serverID]]]", "send the content of this response to a remote logger via Protocol Buffer. `alterFunction` is the same callback than the one in `RemoteLogAction` and `includeCNAME` indicates whether CNAME records inside the response should be parsed and exported. The default is to only exports A and AAAA records. `serverID` is the server identifier." },
   { "rmCacheHitResponseRule", true, "id", "remove cache hit response rule in position 'id', or whose uuid matches if 'id' is an UUID string" },
@@ -474,6 +473,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setServerPolicyLua", true, "name, function", "set server selection policy to one named 'name' and provided by 'function'" },
   { "setServFailWhenNoServer", true, "bool", "if set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query" },
   { "setStaleCacheEntriesTTL", true, "n", "allows using cache entries expired for at most n seconds when there is no backend available to answer for a query" },
+  { "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" },
   { "setTCPDownstreamCleanupInterval", true, "interval", "minimum interval in seconds between two cleanups of the idle TCP downstream connections" },
   { "setTCPUseSinglePipe", true, "bool", "whether the incoming TCP connections should be put into a single queue instead of using per-thread queues. Defaults to false" },
   { "setTCPRecvTimeout", true, "n", "set the read timeout on TCP connections from the client, in seconds" },
index 73d4af36d82a25e5d1fee16bdcccf24354704240..5d9923fb07efafbd5b188d6fd767b53ce6b4a8ce 100644 (file)
  */
 #pragma once
 
+#include <unordered_set>
+
 #include "dolog.hh"
 #include "dnsdist-rings.hh"
+#include "statnode.hh"
+
+#include "dnsdist-lua-inspection-ffi.hh"
+
+// dnsdist_ffi_stat_node_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_stat_node_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_stat_node_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+typedef std::function<bool(dnsdist_ffi_stat_node_t*)> dnsdist_ffi_stat_node_visitor_t;
+
+struct dnsdist_ffi_stat_node_t
+{
+  dnsdist_ffi_stat_node_t(const StatNode& node_, const StatNode::Stat& self_, const StatNode::Stat& children_): node(node_), self(self_), children(children_)
+  {
+  }
+
+  const StatNode& node;
+  const StatNode::Stat& self;
+  const StatNode::Stat& children;
+};
 
 class DynBlockRulesGroup
 {
@@ -119,7 +149,7 @@ private:
     bool d_enabled{false};
   };
 
-    typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
+  typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
 
 public:
   DynBlockRulesGroup()
@@ -149,6 +179,20 @@ public:
     entry = DynBlockRule(reason, blockDuration, rate, warningRate, seconds, action);
   }
 
+  typedef std::function<bool(const StatNode&, const StatNode::Stat&, const StatNode::Stat&)> smtVisitor_t;
+
+  void setSuffixMatchRule(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, smtVisitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitor = visitor;
+  }
+
+  void setSuffixMatchRuleFFI(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, dnsdist_ffi_stat_node_visitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitorFFI = visitor;
+  }
+
   void apply()
   {
     struct timespec now;
@@ -160,6 +204,7 @@ public:
   void apply(const struct timespec& now)
   {
     counts_t counts;
+    StatNode statNodeRoot;
 
     size_t entriesCount = 0;
     if (hasQueryRules()) {
@@ -171,9 +216,9 @@ public:
     counts.reserve(entriesCount);
 
     processQueryRules(counts, now);
-    processResponseRules(counts, now);
+    processResponseRules(counts, statNodeRoot, now);
 
-    if (counts.empty()) {
+    if (counts.empty() && statNodeRoot.empty()) {
       return;
     }
 
@@ -239,6 +284,38 @@ public:
     if (updated && blocks) {
       g_dynblockNMG.setState(*blocks);
     }
+
+    if (!statNodeRoot.empty()) {
+      StatNode::Stat node;
+      std::unordered_set<DNSName> namesToBlock;
+      statNodeRoot.visit([this,&namesToBlock](const StatNode* node_, const StatNode::Stat& self, const StatNode::Stat& children) {
+                           bool block = false;
+
+                           if (d_smtVisitorFFI) {
+                             dnsdist_ffi_stat_node_t tmp(*node_, self, children);
+                             block = d_smtVisitorFFI(&tmp);
+                           }
+                           else {
+                             block = d_smtVisitor(*node_, self, children);
+                           }
+
+                           if (block) {
+                             namesToBlock.insert(DNSName(node_->fullname));
+                           }
+                         },
+        node);
+
+      if (!namesToBlock.empty()) {
+        updated = false;
+        SuffixMatchTree<DynBlock> smtBlocks = g_dynblockSMT.getCopy();
+        for (const auto& name : namesToBlock) {
+          addOrRefreshBlockSMT(smtBlocks, now, name, d_suffixMatchRule, updated);
+        }
+        if (updated) {
+          g_dynblockSMT.setState(smtBlocks);
+        }
+      }
+    }
   }
 
   void excludeRange(const Netmask& range)
@@ -251,12 +328,18 @@ public:
     d_excludedSubnets.addMask(range, false);
   }
 
+  void excludeDomain(const DNSName& domain)
+  {
+    d_excludedDomains.add(domain);
+  }
+
   std::string toString() const
   {
     std::stringstream result;
 
     result << "Query rate rule: " << d_queryRateRule.toString() << std::endl;
     result << "Response rate rule: " << d_respRateRule.toString() << std::endl;
+    result << "SuffixMatch rule: " << d_suffixMatchRule.toString() << std::endl;
     result << "RCode rules: " << std::endl;
     for (const auto& rule : d_rcodeRules) {
       result << "- " << RCode::to_s(rule.first) << ": " << rule.second.toString() << std::endl;
@@ -266,6 +349,7 @@ public:
       result << "- " << QType(rule.first).getName() << ": " << rule.second.toString() << std::endl;
     }
     result << "Excluded Subnets: " << d_excludedSubnets.toString() << std::endl;
+    result << "Excluded Domains: " << d_excludedDomains.toString() << std::endl;
 
     return result.str();
   }
@@ -348,6 +432,44 @@ private:
     updated = true;
   }
 
+  void addOrRefreshBlockSMT(SuffixMatchTree<DynBlock>& blocks, const struct timespec& now, const DNSName& name, const DynBlockRule& rule, bool& updated)
+  {
+    if (d_excludedDomains.check(name)) {
+      /* do not add a block for excluded domains */
+      return;
+    }
+
+    struct timespec until = now;
+    until.tv_sec += rule.d_blockDuration;
+    unsigned int count = 0;
+    const auto& got = blocks.lookup(name);
+    bool expired = false;
+
+    if (got) {
+      if (until < got->until) {
+        // had a longer policy
+        return;
+      }
+
+      if (now < got->until) {
+        // only inherit count on fresh query we are extending
+        count = got->blocks;
+      }
+      else {
+        expired = true;
+      }
+    }
+
+    DynBlock db{rule.d_blockReason, until, name, rule.d_action};
+    db.blocks = count;
+
+    if (!d_beQuiet && (!got || expired)) {
+      warnlog("Inserting dynamic block for %s for %d seconds: %s", name, rule.d_blockDuration, rule.d_blockReason);
+    }
+    blocks.add(name, db);
+    updated = true;
+  }
+
   void addBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const ComboAddress& requestor, const DynBlockRule& rule, bool& updated)
   {
     addOrRefreshBlock(blocks, now, requestor, rule, updated, false);
@@ -368,6 +490,11 @@ private:
     return d_respRateRule.isEnabled() || !d_rcodeRules.empty();
   }
 
+  bool hasSuffixMatchRules() const
+  {
+    return d_suffixMatchRule.isEnabled();
+  }
+
   bool hasRules() const
   {
     return hasQueryRules() || hasResponseRules();
@@ -410,15 +537,18 @@ private:
     }
   }
 
-  void processResponseRules(counts_t& counts, const struct timespec& now)
+  void processResponseRules(counts_t& counts, StatNode& root, const struct timespec& now)
   {
-    if (!hasResponseRules()) {
+    if (!hasResponseRules() && !hasSuffixMatchRules()) {
       return;
     }
 
     d_respRateRule.d_cutOff = d_respRateRule.d_minTime = now;
     d_respRateRule.d_cutOff.tv_sec -= d_respRateRule.d_seconds;
 
+    d_suffixMatchRule.d_cutOff = d_suffixMatchRule.d_minTime = now;
+    d_suffixMatchRule.d_cutOff.tv_sec -= d_suffixMatchRule.d_seconds;
+
     for (auto& rule : d_rcodeRules) {
       rule.second.d_cutOff = rule.second.d_minTime = now;
       rule.second.d_cutOff.tv_sec -= rule.second.d_seconds;
@@ -432,6 +562,7 @@ private:
         }
 
         bool respRateMatches = d_respRateRule.matches(c.when);
+        bool suffixMatchRuleMatches = d_suffixMatchRule.matches(c.when);
         bool rcodeRuleMatches = checkIfResponseCodeMatches(c);
 
         if (respRateMatches || rcodeRuleMatches) {
@@ -443,6 +574,10 @@ private:
             entry.d_rcodeCounts[c.dh.rcode]++;
           }
         }
+
+        if (suffixMatchRuleMatches) {
+          root.submit(c.name, c.dh.rcode, boost::none);
+        }
       }
     }
   }
@@ -451,6 +586,10 @@ private:
   std::map<uint16_t, DynBlockRule> d_qtypeRules;
   DynBlockRule d_queryRateRule;
   DynBlockRule d_respRateRule;
+  DynBlockRule d_suffixMatchRule;
   NetmaskGroup d_excludedSubnets;
+  SuffixMatchNode d_excludedDomains;
+  smtVisitor_t d_smtVisitor;
+  dnsdist_ffi_stat_node_visitor_t d_smtVisitorFFI;
   bool d_beQuiet{false};
 };
index 56422e48a8a992ca5cceb60be62598c072924144..5e8974d6983c9d9fdcd47aac8c78ee1b1273cba5 100644 (file)
@@ -257,10 +257,10 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
-  dh.d_clen = htons((uint16_t) optRData.length());
+  dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
   res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
-  res.assign((const char *) &name, sizeof name);
-  res.append((const char *) &dh, sizeof dh);
+  res.assign(reinterpret_cast<const char *>(&name), sizeof name);
+  res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
   res.append(optRData.c_str(), optRData.length());
 }
 
@@ -464,9 +464,7 @@ static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16
 
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
 {
-  /* we need at least:
-     root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (*optLen < 11) {
+  if (*optLen < optRecordMinimumSize) {
     return EINVAL;
   }
   const unsigned char* end = (const unsigned char*) optStart + *optLen;
@@ -490,15 +488,13 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
 
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
 {
-  /* we need at least:
-   root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (optLen < 11) {
+  if (optLen < optRecordMinimumSize) {
     return false;
   }
   size_t p = optStart + 9;
   uint16_t rdLen = (0x100*packet.at(p) + packet.at(p+1));
   p += sizeof(rdLen);
-  if (rdLen > (optLen - 11)) {
+  if (rdLen > (optLen - optRecordMinimumSize)) {
     return false;
   }
 
@@ -741,3 +737,31 @@ bool queryHasEDNS(const DNSQuestion& dq)
 
   return false;
 }
+
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
+{
+  uint16_t optStart;
+  size_t optLen = 0;
+  bool last = false;
+  const char * packet = reinterpret_cast<const char*>(dq.dh);
+  std::string packetStr(packet, dq.len);
+  int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+  if (res != 0) {
+    // no EDNS OPT RR
+    return false;
+  }
+
+  if (optLen < optRecordMinimumSize) {
+    return false;
+  }
+
+  if (optStart < dq.len && packetStr.at(optStart) != 0) {
+    // OPT RR Name != '.'
+    return false;
+  }
+
+  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
+  memcpy(&edns0, packet + optStart + 5, sizeof edns0);
+  return true;
+}
index 7c3739f443c2546c46101c4935169c9207c48496..767575723f059ca4320dfb89c88a055fb37fb300 100644 (file)
@@ -21,6 +21,9 @@
  */
 #pragma once
 
+// root label (1), type (2), class (2), ttl (4) + rdlen (2)
+static const size_t optRecordMinimumSize = 11;
+
 extern size_t g_EdnsUDPPayloadSize;
 extern uint16_t g_PayloadSizeSelfGenAnswers;
 
@@ -42,3 +45,4 @@ bool parseEDNSOptions(DNSQuestion& dq);
 
 int getEDNSZ(const DNSQuestion& dq);
 bool queryHasEDNS(const DNSQuestion& dq);
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0);
index 55c7251e0aaca0866ec83c78e1ffc17cdf89c5bf..c9cadfc5ee01ea23fc2c3e8b77e8ce683c3a2f90 100644 (file)
@@ -19,6 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+#include "config.h"
 #include "threadname.hh"
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "ednsoptions.hh"
 #include "fstrm_logger.hh"
 #include "remote_logger.hh"
-#include "boost/optional/optional_io.hpp"
+
+#include <boost/optional/optional_io.hpp>
+
+#ifdef HAVE_LIBCRYPTO
+#include "ipcipher.hh"
+#endif /* HAVE_LIBCRYPTO */
 
 class DropAction : public DNSAction
 {
@@ -739,7 +745,7 @@ private:
 class DnstapLogAction : public DNSAction, public boost::noncopyable
 {
 public:
-  DnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
+  DnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
   {
   }
   DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
@@ -749,7 +755,7 @@ public:
     {
       if (d_alterFunc) {
         std::lock_guard<std::mutex> lock(g_luamutex);
-        (*d_alterFunc)(*dq, &message);
+        (*d_alterFunc)(dq, &message);
       }
     }
     std::string data;
@@ -765,13 +771,13 @@ public:
 private:
   std::string d_identity;
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > d_alterFunc;
 };
 
 class RemoteLogAction : public DNSAction, public boost::noncopyable
 {
 public:
-  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID)
+  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey)
   {
   }
   DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
@@ -786,9 +792,16 @@ public:
       message.setServerIdentity(d_serverID);
     }
 
+#if HAVE_LIBCRYPTO
+    if (!d_ipEncryptKey.empty())
+    {
+      message.setRequestor(encryptCA(*dq->remote, d_ipEncryptKey));
+    }
+#endif /* HAVE_LIBCRYPTO */
+
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      (*d_alterFunc)(*dq, &message);
+      (*d_alterFunc)(dq, &message);
     }
 
     std::string data;
@@ -803,8 +816,9 @@ public:
   }
 private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
+  std::string d_ipEncryptKey;
 };
 
 class SNMPTrapAction : public DNSAction
@@ -857,7 +871,7 @@ private:
 class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable
 {
 public:
-  DnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
+  DnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
   {
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
@@ -869,7 +883,7 @@ public:
     {
       if (d_alterFunc) {
         std::lock_guard<std::mutex> lock(g_luamutex);
-        (*d_alterFunc)(*dr, &message);
+        (*d_alterFunc)(dr, &message);
       }
     }
     std::string data;
@@ -885,13 +899,13 @@ public:
 private:
   std::string d_identity;
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > d_alterFunc;
 };
 
 class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
 {
 public:
-  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_includeCNAME(includeCNAME)
+  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME)
   {
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
@@ -906,9 +920,16 @@ public:
       message.setServerIdentity(d_serverID);
     }
 
+#if HAVE_LIBCRYPTO
+    if (!d_ipEncryptKey.empty())
+    {
+      message.setRequestor(encryptCA(*dr->remote, d_ipEncryptKey));
+    }
+#endif /* HAVE_LIBCRYPTO */
+
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      (*d_alterFunc)(*dr, &message);
+      (*d_alterFunc)(dr, &message);
     }
 
     std::string data;
@@ -923,8 +944,9 @@ public:
   }
 private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
+  std::string d_ipEncryptKey;
   bool d_includeCNAME;
 };
 
@@ -1053,14 +1075,6 @@ void setupLuaActions()
       addAction(&g_rulactions, var, boost::get<std::shared_ptr<DNSAction> >(era), params);
     });
 
-  g_lua.writeFunction("addLuaAction", [](luadnsrule_t var, LuaAction::func_t func, boost::optional<luaruleparams_t> params) {
-      addAction(&g_rulactions, var, std::make_shared<LuaAction>(func), params);
-    });
-
-  g_lua.writeFunction("addLuaResponseAction", [](luadnsrule_t var, LuaResponseAction::func_t func, boost::optional<luaruleparams_t> params) {
-      addAction(&g_resprulactions, var, std::make_shared<LuaResponseAction>(func), params);
-    });
-
   g_lua.writeFunction("addResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) {
       if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
         throw std::runtime_error("addResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
@@ -1212,7 +1226,7 @@ void setupLuaActions()
       return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(func));
     });
 
-  g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
+  g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       // avoids potentially-evaluated-expression warning with clang.
       RemoteLoggerInterface& rl = *logger.get();
       if (typeid(rl) != typeid(RemoteLogger)) {
@@ -1221,20 +1235,24 @@ void setupLuaActions()
       }
 
       std::string serverID;
+      std::string ipEncryptKey;
       if (vars) {
         if (vars->count("serverID")) {
           serverID = boost::get<std::string>((*vars)["serverID"]);
         }
+        if (vars->count("ipEncryptKey")) {
+          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
+        }
       }
 
 #ifdef HAVE_PROTOBUF
-      return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID));
+      return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID, ipEncryptKey));
 #else
       throw std::runtime_error("Protobuf support is required to use RemoteLogAction");
 #endif
     });
 
-  g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<bool> includeCNAME, boost::optional<std::unordered_map<std::string, std::string>> vars) {
+  g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<bool> includeCNAME, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       // avoids potentially-evaluated-expression warning with clang.
       RemoteLoggerInterface& rl = *logger.get();
       if (typeid(rl) != typeid(RemoteLogger)) {
@@ -1243,20 +1261,24 @@ void setupLuaActions()
       }
 
       std::string serverID;
+      std::string ipEncryptKey;
       if (vars) {
         if (vars->count("serverID")) {
           serverID = boost::get<std::string>((*vars)["serverID"]);
         }
+        if (vars->count("ipEncryptKey")) {
+          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
+        }
       }
 
 #ifdef HAVE_PROTOBUF
-      return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, includeCNAME ? *includeCNAME : false));
+      return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, ipEncryptKey, includeCNAME ? *includeCNAME : false));
 #else
       throw std::runtime_error("Protobuf support is required to use RemoteLogResponseAction");
 #endif
     });
 
-  g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > alterFunc) {
+  g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc) {
 #ifdef HAVE_PROTOBUF
       return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, alterFunc));
 #else
@@ -1264,7 +1286,7 @@ void setupLuaActions()
 #endif
     });
 
-  g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > alterFunc) {
+  g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc) {
 #ifdef HAVE_PROTOBUF
       return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, alterFunc));
 #else
index 79e01d2e13792acb0162eab0757f63caea9d4115..b29fc3edf3b4eaa20c5b95d1e3d0d21d69d8eee5 100644 (file)
@@ -23,6 +23,7 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 
+#include "config.h"
 #include "dnsdist.hh"
 #include "dnsdist-lua.hh"
 #include "dnsdist-protobuf.hh"
 #include "fstrm_logger.hh"
 #include "remote_logger.hh"
 
+#ifdef HAVE_LIBCRYPTO
+#include "ipcipher.hh"
+#endif /* HAVE_LIBCRYPTO */
+
 void setupLuaBindings(bool client)
 {
   g_lua.writeFunction("infolog", [](const string& arg) {
@@ -166,6 +171,19 @@ void setupLuaBindings(bool client)
   g_lua.registerFunction<ComboAddress(ComboAddress::*)()>("mapToIPv4", [](const ComboAddress& ca) { return ca.mapToIPv4(); });
   g_lua.registerFunction<bool(nmts_t::*)(const ComboAddress&)>("match", [](nmts_t& s, const ComboAddress& ca) { return s.match(ca); });
 
+#ifdef HAVE_LIBCRYPTO
+  g_lua.registerFunction<ComboAddress(ComboAddress::*)(const std::string& key)>("ipencrypt", [](const ComboAddress& ca, const std::string& key) {
+      return encryptCA(ca, key);
+    });
+  g_lua.registerFunction<ComboAddress(ComboAddress::*)(const std::string& key)>("ipdecrypt", [](const ComboAddress& ca, const std::string& key) {
+      return decryptCA(ca, key);
+    });
+
+  g_lua.writeFunction("makeIPCipherKey", [](const std::string& password) {
+      return makeIPCipherKey(password);
+    });
+#endif /* HAVE_LIBCRYPTO */
+
   /* DNSName */
   g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
   g_lua.registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });
@@ -230,9 +248,18 @@ void setupLuaBindings(bool client)
 #endif /* HAVE_EBPF */
 
   /* PacketCache */
-  g_lua.writeFunction("newPacketCache", [](size_t maxEntries, boost::optional<uint32_t> maxTTL, boost::optional<uint32_t> minTTL, boost::optional<uint32_t> tempFailTTL, boost::optional<uint32_t> staleTTL, boost::optional<bool> dontAge, boost::optional<size_t> numberOfShards, boost::optional<bool> deferrableInsertLock, boost::optional<uint32_t> maxNegativeTTL, boost::optional<bool> ecsParsing, boost::optional<std::unordered_map<std::string, boost::variant<bool, size_t>>> vars) {
+  g_lua.writeFunction("newPacketCache", [](size_t maxEntries, boost::optional<std::unordered_map<std::string, boost::variant<bool, size_t>>> vars) {
 
       bool keepStaleData = false;
+      size_t maxTTL = 86400;
+      size_t minTTL = 0;
+      size_t tempFailTTL = 60;
+      size_t maxNegativeTTL = 3600;
+      size_t staleTTL = 60;
+      size_t numberOfShards = 1;
+      bool dontAge = false;
+      bool deferrableInsertLock = true;
+      bool ecsParsing = false;
 
       if (vars) {
 
@@ -279,10 +306,9 @@ void setupLuaBindings(bool client)
         if (vars->count("temporaryFailureTTL")) {
           tempFailTTL = boost::get<size_t>((*vars)["temporaryFailureTTL"]);
         }
-
       }
 
-      auto res = std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL ? *maxTTL : 86400, minTTL ? *minTTL : 0, tempFailTTL ? *tempFailTTL : 60, maxNegativeTTL ? *maxNegativeTTL : 3600, staleTTL ? *staleTTL : 60, dontAge ? *dontAge : false, numberOfShards ? *numberOfShards : 1, deferrableInsertLock ? *deferrableInsertLock : true, ecsParsing ? *ecsParsing : false);
+      auto res = std::make_shared<DNSDistPacketCache>(maxEntries, maxTTL, minTTL, tempFailTTL, maxNegativeTTL, staleTTL, dontAge, numberOfShards, deferrableInsertLock, ecsParsing);
 
       res->setKeepStaleData(keepStaleData);
 
@@ -659,10 +685,10 @@ void setupLuaBindings(bool client)
   g_lua.registerFunction<size_t(EDNSOptionView::*)()>("count", [](const EDNSOptionView& option) {
       return option.values.size();
     });
-  g_lua.registerFunction<std::vector<std::pair<int, string>>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
-    std::vector<std::pair<int, string> > values;
+  g_lua.registerFunction<std::vector<string>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
+    std::vector<string> values;
     for (const auto& value : option.values) {
-      values.push_back(std::make_pair(values.size(), std::string(value.content, value.size)));
+      values.push_back(std::string(value.content, value.size));
     }
     return values;
   });
index c64bf563f1bf4c5dab3ccace2b30b09d7acc1e2f..25d9a187d3acd45da71d1dea8e67a797e44bfe26 100644 (file)
@@ -552,10 +552,38 @@ void setupLuaInspection()
 
   g_lua.writeFunction("showTCPStats", [] {
       setLuaNoSideEffect();
+      ostringstream ret;
       boost::format fmt("%-10d %-10d %-10d %-10d\n");
-      g_outputBuffer += (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued").str();
-      g_outputBuffer += (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections).str();
-      g_outputBuffer += "Query distribution mode is: " + std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") + "\n";
+      ret << (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued") << endl;
+      ret << (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections) << endl;
+      ret <<endl;
+
+      ret << "Query distribution mode is: " << std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") << endl;
+      ret << endl;
+
+      ret << "Frontends:" << endl;
+      fmt = boost::format("%-3d %-20.20s %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f");
+      ret << (fmt % "#" % "Address" % "Connnections" % "Died reading query" % "Died sending response" % "Gave up" % "Client timeouts" % "Downstream timeouts" % "Avg queries/conn" % "Avg duration") << endl;
+
+      size_t counter = 0;
+      for(const auto& f : g_frontends) {
+        ret << (fmt % counter % f->local.toStringWithPort() % f->tcpCurrentConnections % f->tcpDiedReadingQuery % f->tcpDiedSendingResponse % f->tcpGaveUp % f->tcpClientTimeouts % f->tcpDownstreamTimeouts % f->tcpAvgQueriesPerConnection % f->tcpAvgConnectionDuration) << endl;
+        ++counter;
+      }
+      ret << endl;
+
+      ret << "Backends:" << endl;
+      fmt = boost::format("%-3d %-20.20s %-20.20s %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f");
+      ret << (fmt % "#" % "Name" % "Address" % "Connections" % "Died sending query" % "Died reading response" % "Gave up" % "Read timeouts" % "Write timeouts" % "Avg queries/conn" % "Avg duration") << endl;
+
+      auto states = g_dstates.getLocal();
+      counter = 0;
+      for(const auto& s : *states) {
+        ret << (fmt % counter % s->name % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
+        ++counter;
+      }
+
+      g_outputBuffer=ret.str();
     });
 
   g_lua.writeFunction("dumpStats", [] {
@@ -660,6 +688,16 @@ void setupLuaInspection()
         group->setResponseByteRate(rate, warningRate ? *warningRate : 0, seconds, reason, blockDuration, action ? *action : DNSAction::Action::None);
       }
     });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, DynBlockRulesGroup::smtVisitor_t)>("setSuffixMatchRule", [](std::shared_ptr<DynBlockRulesGroup>& group, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, DynBlockRulesGroup::smtVisitor_t visitor) {
+      if (group) {
+        group->setSuffixMatchRule(seconds, reason, blockDuration, action ? *action : DNSAction::Action::None, visitor);
+      }
+    });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, dnsdist_ffi_stat_node_visitor_t)>("setSuffixMatchRuleFFI", [](std::shared_ptr<DynBlockRulesGroup>& group, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, dnsdist_ffi_stat_node_visitor_t visitor) {
+      if (group) {
+        group->setSuffixMatchRuleFFI(seconds, reason, blockDuration, action ? *action : DNSAction::Action::None, visitor);
+      }
+    });
   g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(uint8_t, unsigned int, unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, boost::optional<unsigned int>)>("setRCodeRate", [](std::shared_ptr<DynBlockRulesGroup>& group, uint8_t rcode, unsigned int rate, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, boost::optional<unsigned int> warningRate) {
       if (group) {
         group->setRCodeRate(rcode, rate, warningRate ? *warningRate : 0, seconds, reason, blockDuration, action ? *action : DNSAction::Action::None);
@@ -690,6 +728,16 @@ void setupLuaInspection()
         group->includeRange(Netmask(*boost::get<std::string>(&ranges)));
       }
     });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(boost::variant<std::string, std::vector<std::pair<int, std::string>>>)>("excludeDomains", [](std::shared_ptr<DynBlockRulesGroup>& group, boost::variant<std::string, std::vector<std::pair<int, std::string>>> domains) {
+      if (domains.type() == typeid(std::vector<std::pair<int, std::string>>)) {
+        for (const auto& range : *boost::get<std::vector<std::pair<int, std::string>>>(&domains)) {
+          group->excludeDomain(DNSName(range.second));
+        }
+      }
+      else {
+        group->excludeDomain(DNSName(*boost::get<std::string>(&domains)));
+      }
+    });
   g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)()>("apply", [](std::shared_ptr<DynBlockRulesGroup>& group) {
     group->apply();
   });
index d4d199436e1e78634e290a21d63fc4b26b9544bf..b912cda51b4ef6028a3e3cd8cd02e479841035eb 100644 (file)
@@ -22,6 +22,8 @@
 #include "dnsdist.hh"
 #include "ednsoptions.hh"
 
+#undef BADSIG  // signal.h SIG_ERR
+
 void setupLuaVars()
 {
   g_lua.writeVariable("DNSAction", std::unordered_map<string,int>{
@@ -85,33 +87,54 @@ void setupLuaVars()
       {"KEYTAG",       EDNSOptionCode::KEYTAG }
     });
 
-  vector<pair<string, int> > rcodes = {{"NOERROR",  RCode::NoError  },
-                                       {"FORMERR",  RCode::FormErr  },
-                                       {"SERVFAIL", RCode::ServFail },
-                                       {"NXDOMAIN", RCode::NXDomain },
-                                       {"NOTIMP",   RCode::NotImp   },
-                                       {"REFUSED",  RCode::Refused  },
-                                       {"YXDOMAIN", RCode::YXDomain },
-                                       {"YXRRSET",  RCode::YXRRSet  },
-                                       {"NXRRSET",  RCode::NXRRSet  },
-                                       {"NOTAUTH",  RCode::NotAuth  },
-                                       {"NOTZONE",  RCode::NotZone  },
-                                       {"BADVERS",  ERCode::BADVERS },
-                                       {"BADSIG",   ERCode::BADSIG  },
-                                       {"BADKEY",   ERCode::BADKEY  },
-                                       {"BADTIME",  ERCode::BADTIME   },
-                                       {"BADMODE",  ERCode::BADMODE   },
-                                       {"BADNAME",  ERCode::BADNAME   },
-                                       {"BADALG",   ERCode::BADALG    },
-                                       {"BADTRUNC", ERCode::BADTRUNC  },
-                                       {"BADCOOKIE",ERCode::BADCOOKIE },
-  };
+  g_lua.writeVariable("DNSRCode", std::unordered_map<string, int>{
+      {"NOERROR",  RCode::NoError  },
+      {"FORMERR",  RCode::FormErr  },
+      {"SERVFAIL", RCode::ServFail },
+      {"NXDOMAIN", RCode::NXDomain },
+      {"NOTIMP",   RCode::NotImp   },
+      {"REFUSED",  RCode::Refused  },
+      {"YXDOMAIN", RCode::YXDomain },
+      {"YXRRSET",  RCode::YXRRSet  },
+      {"NXRRSET",  RCode::NXRRSet  },
+      {"NOTAUTH",  RCode::NotAuth  },
+      {"NOTZONE",  RCode::NotZone  },
+      {"BADVERS",  ERCode::BADVERS },
+      {"BADSIG",   ERCode::BADSIG  },
+      {"BADKEY",   ERCode::BADKEY  },
+      {"BADTIME",  ERCode::BADTIME   },
+      {"BADMODE",  ERCode::BADMODE   },
+      {"BADNAME",  ERCode::BADNAME   },
+      {"BADALG",   ERCode::BADALG    },
+      {"BADTRUNC", ERCode::BADTRUNC  },
+      {"BADCOOKIE",ERCode::BADCOOKIE }
+  });
+
   vector<pair<string, int> > dd;
   for(const auto& n : QType::names)
     dd.push_back({n.first, n.second});
-  for(const auto& n : rcodes)
-    dd.push_back({n.first, n.second});
-  g_lua.writeVariable("dnsdist", dd);
+  g_lua.writeVariable("DNSQType", dd);
+
+  g_lua.executeCode(R"LUA(
+    local tables = {
+      DNSQType = DNSQType,
+      DNSRCode = DNSRCode
+    }
+    local function index (table, key)
+      for tname,t in pairs(tables)
+      do
+        local val = t[key]
+        if val then
+          warnlog(string.format("access to dnsdist.%s is deprecated, please use %s.%s", key, tname, key))
+          return val
+        end
+      end
+    end
+
+    dnsdist = {}
+    setmetatable(dnsdist, { __index = index })
+    )LUA"
+  );
 
 #ifdef HAVE_DNSCRYPT
     g_lua.writeVariable("DNSCryptExchangeVersion", std::unordered_map<string,int>{
index c7e03da7b3bdafb2f965ba23744860a83e0a2687..ad4c55472354e2a6f47df8652d8bf0027a534fc3 100644 (file)
@@ -497,10 +497,23 @@ void setupLuaConfig(bool client)
 
       try {
        ComboAddress loc(addr, 53);
-       g_locals.clear();
-       g_locals.push_back(std::make_tuple(loc, doTCP, reusePort, tcpFastOpenQueueSize, interface, cpus)); /// only works pre-startup, so no sync necessary
+        for (auto it = g_frontends.begin(); it != g_frontends.end(); ) {
+          /* TLS and DNSCrypt frontends are separate */
+          if ((*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr) {
+            it = g_frontends.erase(it);
+          }
+          else {
+            ++it;
+          }
+        }
+
+        // only works pre-startup, so no sync necessary
+        g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus)));
+        if (doTCP) {
+          g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus)));
+        }
       }
-      catch(std::exception& e) {
+      catch(const std::exception& e) {
        g_outputBuffer="Error: "+string(e.what())+"\n";
       }
     });
@@ -523,7 +536,11 @@ void setupLuaConfig(bool client)
 
       try {
        ComboAddress loc(addr, 53);
-       g_locals.push_back(std::make_tuple(loc, doTCP, reusePort, tcpFastOpenQueueSize, interface, cpus)); /// only works pre-startup, so no sync necessary
+        // only works pre-startup, so no sync necessary
+        g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus)));
+        if (doTCP) {
+          g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus)));
+        }
       }
       catch(std::exception& e) {
        g_outputBuffer="Error: "+string(e.what())+"\n";
@@ -1095,7 +1112,17 @@ void setupLuaConfig(bool client)
 
       try {
         auto ctx = std::make_shared<DNSCryptContext>(providerName, certFile, keyFile);
-        g_dnsCryptLocals.push_back(std::make_tuple(ComboAddress(addr, 443), ctx, reusePort, tcpFastOpenQueueSize, interface, cpus));
+
+        /* UDP */
+        auto cs = std::unique_ptr<ClientState>(new ClientState(ComboAddress(addr, 443), false, reusePort, tcpFastOpenQueueSize, interface, cpus));
+        cs->dnscryptCtx = ctx;
+        g_dnsCryptLocals.push_back(ctx);
+        g_frontends.push_back(std::move(cs));
+
+        /* TCP */
+        cs = std::unique_ptr<ClientState>(new ClientState(ComboAddress(addr, 443), true, reusePort, tcpFastOpenQueueSize, interface, cpus));
+        cs->dnscryptCtx = ctx;
+        g_frontends.push_back(std::move(cs));
       }
       catch(std::exception& e) {
         errlog(e.what());
@@ -1115,9 +1142,9 @@ void setupLuaConfig(bool client)
       ret << (fmt % "#" % "Address" % "Provider Name") << endl;
       size_t idx = 0;
 
-      for (const auto& local : g_dnsCryptLocals) {
-        const std::shared_ptr<DNSCryptContext> ctx = std::get<1>(local);
-        ret<< (fmt % idx % std::get<0>(local).toStringWithPort() % ctx->getProviderName()) << endl;
+      for (const auto& frontend : g_frontends) {
+        const std::shared_ptr<DNSCryptContext> ctx = frontend->dnscryptCtx;
+        ret<< (fmt % idx % frontend->local.toStringWithPort() % ctx->getProviderName()) << endl;
         idx++;
       }
 
@@ -1132,7 +1159,7 @@ void setupLuaConfig(bool client)
 #ifdef HAVE_DNSCRYPT
       std::shared_ptr<DNSCryptContext> ret = nullptr;
       if (idx < g_dnsCryptLocals.size()) {
-        ret = std::get<1>(g_dnsCryptLocals.at(idx));
+        ret = g_dnsCryptLocals.at(idx);
       }
       return ret;
 #else
@@ -1268,13 +1295,13 @@ void setupLuaConfig(bool client)
       setLuaNoSideEffect();
       try {
         ostringstream ret;
-        boost::format fmt("%1$-3d %2$-20.20s %|25t|%3$-8.8s %|35t|%4%" );
+        boost::format fmt("%1$-3d %2$-20.20s %|35t|%3$-20.20s %|57t|%4%" );
         //             1    2           3            4
         ret << (fmt % "#" % "Address" % "Protocol" % "Queries" ) << endl;
 
         size_t counter = 0;
         for (const auto& front : g_frontends) {
-          ret << (fmt % counter % front->local.toStringWithPort() % (front->udpFD != -1 ? "UDP" : "TCP") % front->queries) << endl;
+          ret << (fmt % counter % front->local.toStringWithPort() % front->getType() % front->queries) << endl;
           counter++;
         }
         g_outputBuffer=ret.str();
@@ -1285,7 +1312,7 @@ void setupLuaConfig(bool client)
       setLuaNoSideEffect();
       ClientState* ret = nullptr;
       if(num < g_frontends.size()) {
-        ret=g_frontends[num];
+        ret=g_frontends[num].get();
       }
       return ret;
       });
@@ -1458,6 +1485,11 @@ void setupLuaConfig(bool client)
       g_servFailOnNoPolicy = servfail;
     });
 
+  g_lua.writeFunction("setRoundRobinFailOnNoServer", [](bool fail) {
+      setLuaSideEffect();
+      g_roundrobinFailOnNoServer = fail;
+    });
+
   g_lua.writeFunction("setRingBuffersSize", [](size_t capacity, boost::optional<size_t> numberOfShards) {
       setLuaSideEffect();
       if (g_configurationDone) {
@@ -1609,6 +1641,15 @@ void setupLuaConfig(bool client)
       g_secPollInterval = newInterval;
   });
 
+  g_lua.writeFunction("setSyslogFacility", [](int facility) {
+    setLuaSideEffect();
+    if (g_configurationDone) {
+      g_outputBuffer="setSyslogFacility cannot be used at runtime!\n";
+      return;
+    }
+    setSyslogFacility(facility);
+  });
+
   g_lua.writeFunction("addTLSLocal", [client](const std::string& addr, boost::variant<std::string, std::vector<std::pair<int,std::string>>> certFiles, boost::variant<std::string, std::vector<std::pair<int,std::string>>> keyFiles, boost::optional<localbind_t> vars) {
         if (client)
           return;
@@ -1624,9 +1665,16 @@ void setupLuaConfig(bool client)
           return;
         }
 
+        bool doTCP = true;
+        bool reusePort = false;
+        int tcpFastOpenQueueSize = 0;
+        std::string interface;
+        std::set<int> cpus;
+        (void) doTCP;
+
         if (vars) {
           bool doTCP = true;
-          parseLocalBindVars(vars, doTCP, frontend->d_reusePort, frontend->d_tcpFastOpenQueueSize, frontend->d_interface, frontend->d_cpus);
+          parseLocalBindVars(vars, doTCP, reusePort, tcpFastOpenQueueSize, interface, cpus);
 
           if (vars->count("provider")) {
             frontend->d_provider = boost::get<const string>((*vars)["provider"]);
@@ -1666,7 +1714,11 @@ void setupLuaConfig(bool client)
         try {
           frontend->d_addr = ComboAddress(addr, 853);
           vinfolog("Loading TLS provider %s", frontend->d_provider);
-          g_tlslocals.push_back(frontend); /// only works pre-startup, so no sync necessary
+          // only works pre-startup, so no sync necessary
+          auto cs = std::unique_ptr<ClientState>(new ClientState(frontend->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus));
+          cs->tlsFrontend = frontend;
+          g_tlslocals.push_back(cs->tlsFrontend);
+          g_frontends.push_back(std::move(cs));
         }
         catch(const std::exception& e) {
           g_outputBuffer="Error: "+string(e.what())+"\n";
@@ -1766,7 +1818,30 @@ void setupLuaConfig(bool client)
 #endif
       });
 
-  g_lua.writeFunction("setAllowEmptyResponse", [](bool allow) { g_allowEmptyResponse=allow; });
+    g_lua.writeFunction("reloadAllCertificates", []() {
+        for (auto& frontend : g_frontends) {
+          if (!frontend) {
+            continue;
+          }
+          try {
+#ifdef HAVE_DNSCRYPT
+            if (frontend->dnscryptCtx) {
+              frontend->dnscryptCtx->reloadCertificate();
+            }
+#endif /* HAVE_DNSCRYPT */
+#ifdef HAVE_DNS_OVER_TLS
+            if (frontend->tlsFrontend) {
+              frontend->tlsFrontend->setupTLS();
+            }
+#endif /* HAVE_DNS_OVER_TLS */
+          }
+          catch(const std::exception& e) {
+            errlog("Error reloading certificates for frontend %s: %s", frontend->local.toStringWithPort(), e.what());
+          }
+        }
+      });
+
+    g_lua.writeFunction("setAllowEmptyResponse", [](bool allow) { g_allowEmptyResponse=allow; });
 }
 
 vector<std::function<void(void)>> setupLua(bool client, const std::string& config)
index 777865ff162b5f130234964f0f21f6beadc58984..1c51cdc16cab584b5df2f00c31c1262a877c61d0 100644 (file)
@@ -35,6 +35,8 @@
 #include <atomic>
 #include <netinet/tcp.h>
 
+#include "sstuff.hh"
+
 using std::thread;
 using std::atomic;
 
@@ -53,49 +55,165 @@ using std::atomic;
    Let's start naively.
 */
 
-static int setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures)
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+static const size_t g_maxCachedConnectionsPerDownstream = 20;
+uint64_t g_maxTCPQueuedConnections{1000};
+size_t g_maxTCPQueriesPerConn{0};
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+uint16_t g_downstreamTCPCleanupInterval{60};
+bool g_useTCPSinglePipe{false};
+
+static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
 {
+  std::unique_ptr<Socket> result;
+
   do {
     vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
-    int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+    result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
     try {
       if (!IsAnyAddress(ds->sourceAddr)) {
-        SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
+        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef IP_BIND_ADDRESS_NO_PORT
         if (ds->ipBindAddrNoPort) {
-          SSetsockopt(sock, SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
         }
 #endif
-        SBind(sock, ds->sourceAddr);
+        result->bind(ds->sourceAddr, false);
       }
-      setNonBlocking(sock);
+      result->setNonBlocking();
 #ifdef MSG_FASTOPEN
       if (!ds->tcpFastOpen) {
-        SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
       }
 #else
-      SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
 #endif /* MSG_FASTOPEN */
-      return sock;
+      return result;
     }
     catch(const std::runtime_error& e) {
-      /* don't leak our file descriptor if SConnect() (for example) throws */
+      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
       downstreamFailures++;
-      close(sock);
       if (downstreamFailures > ds->retries) {
         throw;
       }
     }
   } while(downstreamFailures <= ds->retries);
 
-  return -1;
+  return nullptr;
+}
+
+class TCPConnectionToBackend
+{
+public:
+  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now)
+  {
+    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
+    ++d_ds->tcpCurrentConnections;
+  }
+
+  ~TCPConnectionToBackend()
+  {
+    if (d_ds && d_socket) {
+      --d_ds->tcpCurrentConnections;
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      auto diff = now - d_connectionStartTime;
+      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
+    }
+  }
+
+  int getHandle() const
+  {
+    if (!d_socket) {
+      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
+    }
+
+    return d_socket->getHandle();
+  }
+
+  const ComboAddress& getRemote() const
+  {
+    return d_ds->remote;
+  }
+
+  bool isFresh() const
+  {
+    return d_fresh;
+  }
+
+  void incQueries()
+  {
+    ++d_queries;
+  }
+
+  void setReused()
+  {
+    d_fresh = false;
+  }
+
+private:
+  std::unique_ptr<Socket> d_socket{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  struct timeval d_connectionStartTime;
+  uint64_t d_queries{0};
+  bool d_fresh{true};
+};
+
+static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+
+static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
+{
+  std::unique_ptr<TCPConnectionToBackend> result;
+
+  const auto& it = t_downstreamConnections.find(ds->remote);
+  if (it != t_downstreamConnections.end()) {
+    auto& list = it->second;
+    if (!list.empty()) {
+      result = std::move(list.front());
+      list.pop_front();
+      result->setReused();
+      return result;
+    }
+  }
+
+  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
+}
+
+static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
+{
+  if (conn == nullptr) {
+    return;
+  }
+
+  const auto& remote = conn->getRemote();
+  const auto& it = t_downstreamConnections.find(remote);
+  if (it != t_downstreamConnections.end()) {
+    auto& list = it->second;
+    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
+      /* too many connections queued already */
+      conn.reset();
+      return;
+    }
+    list.push_back(std::move(conn));
+  }
+  else {
+    t_downstreamConnections[remote].push_back(std::move(conn));
+  }
 }
 
 struct ConnectionInfo
 {
-  ConnectionInfo(): cs(nullptr), fd(-1)
+  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
   {
   }
+  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
+  {
+    rhs.cs = nullptr;
+    rhs.fd = -1;
+  }
 
   ConnectionInfo(const ConnectionInfo& rhs) = delete;
   ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
@@ -116,6 +234,9 @@ struct ConnectionInfo
       close(fd);
       fd = -1;
     }
+    if (cs) {
+      --cs->tcpCurrentConnections;
+    }
   }
 
   ComboAddress remote;
@@ -123,15 +244,6 @@ struct ConnectionInfo
   int fd{-1};
 };
 
-uint64_t g_maxTCPQueuedConnections{1000};
-size_t g_maxTCPQueriesPerConn{0};
-size_t g_maxTCPConnectionDuration{0};
-size_t g_maxTCPConnectionsPerClient{0};
-static std::mutex tcpClientsCountMutex;
-static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-bool g_useTCPSinglePipe{false};
-std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
-
 void tcpClientThread(int pipefd);
 
 static void decrementTCPClientCount(const ComboAddress& client)
@@ -161,6 +273,13 @@ void TCPClientCollection::addTCPClientThread()
       return;
     }
 
+    if (!setNonBlocking(pipefds[0])) {
+      close(pipefds[0]);
+      close(pipefds[1]);
+      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
+      return;
+    }
+
     if (!setNonBlocking(pipefds[1])) {
       close(pipefds[0]);
       close(pipefds[1]);
@@ -201,544 +320,914 @@ void TCPClientCollection::addTCPClientThread()
   ++d_numthreads;
 }
 
-static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
-try
-{
-  uint16_t raw;
-  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
-}
-
-static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len)
-try
+static void cleanupClosedTCPConnections()
 {
-  uint16_t raw;
-  size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
-}
+  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
+    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
+      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
+        ++connIt;
+      }
+      else {
+        connIt = dsIt->second.erase(connIt);
+      }
+    }
 
-static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
-{
-  if (maxConnectionDuration) {
-    time_t curtime = time(nullptr);
-    unsigned int elapsed = 0;
-    if (curtime > start) { // To prevent issues when time goes backward
-      elapsed = curtime - start;
+    if (!dsIt->second.empty()) {
+      ++dsIt;
     }
-    if (elapsed >= maxConnectionDuration) {
-      return true;
+    else {
+      dsIt = t_downstreamConnections.erase(dsIt);
     }
-    remainingTime = maxConnectionDuration - elapsed;
   }
-  return false;
 }
 
-void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
+/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+   Updates pos everytime a successful read occurs,
+   throws an std::runtime_error in case of IO error,
+   return Done when toRead bytes have been read, needRead or needWrite if the IO operation
+   would block.
+*/
+// XXX could probably be implemented as a TCPIOHandler
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
 {
-  for(auto it = sockets.begin(); it != sockets.end(); ) {
-    if (isTCPSocketUsable(it->second)) {
-      ++it;
+  if (buffer.size() < (pos + toRead)) {
+    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
+  }
+
+  size_t got = 0;
+  do {
+    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+    if (res == 0) {
+      throw runtime_error("EOF while reading message");
     }
-    else {
-      close(it->second);
-      it = sockets.erase(it);
+    if (res < 0) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        return IOState::NeedRead;
+      }
+      else {
+        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+      }
     }
+
+    pos += static_cast<size_t>(res);
+    got += static_cast<size_t>(res);
   }
+  while (got < toRead);
+
+  return IOState::Done;
 }
 
-std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
 
-void tcpClientThread(int pipefd)
+class TCPClientThreadData
 {
-  /* we get launched with a pipe on which we receive file descriptors from clients that we own
-     from that point on */
-
-  setThreadName("dnsdist/tcpClie");
+public:
+  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  {
+  }
 
-  bool outstanding = false;
-  time_t lastTCPCleanup = time(nullptr);
-  
   LocalHolders holders;
-  auto localRespRulactions = g_resprulactions.getLocal();
-#ifdef HAVE_DNSCRYPT
-  /* when the answer is encrypted in place, we need to get a copy
-     of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
-#endif
+  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
+  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+};
 
-  map<ComboAddress,int> sockets;
-  for(;;) {
-    ConnectionInfo* citmp, ci;
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+class IncomingTCPConnectionState
+{
+public:
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
+  {
+    d_ids.origDest.reset();
+    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+    socklen_t socklen = d_ids.origDest.getSocklen();
+    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
+      d_ids.origDest = d_ci.cs->local;
+    }
+  }
+
+  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+  ~IncomingTCPConnectionState()
+  {
+    decrementTCPClientCount(d_ci.remote);
+    if (d_ci.cs != nullptr) {
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      auto diff = now - d_connectionStartTime;
+      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
+    }
+
+    if (d_ds != nullptr) {
+      if (d_outstanding) {
+        --d_ds->outstanding;
+        d_outstanding = false;
+      }
+
+      if (d_downstreamConnection) {
+        try {
+          if (d_lastIOState == IOState::NeedRead) {
+            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
+            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
+          }
+          else if (d_lastIOState == IOState::NeedWrite) {
+            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
+            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
+          }
+        }
+        catch(const FDMultiplexerException& e) {
+          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+        }
+      }
+    }
 
     try {
-      readn2(pipefd, &citmp, sizeof(citmp));
+      if (d_lastIOState == IOState::NeedRead) {
+        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeReadFD(d_ci.fd);
+      }
+      else if (d_lastIOState == IOState::NeedWrite) {
+        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeWriteFD(d_ci.fd);
+      }
     }
-    catch(const std::runtime_error& e) {
-      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
+    catch(const FDMultiplexerException& e) {
+      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
     }
+  }
 
-    g_tcpclientthreads->decrementQueuedCount();
-    ci=std::move(*citmp);
-    delete citmp;
+  void resetForNewQuery()
+  {
+    d_buffer.resize(sizeof(uint16_t));
+    d_currentPos = 0;
+    d_querySize = 0;
+    d_responseSize = 0;
+    d_downstreamFailures = 0;
+    d_state = State::readingQuerySize;
+    d_lastIOState = IOState::Done;
+  }
 
-    uint16_t qlen, rlen;
-    vector<uint8_t> rewrittenResponse;
-    shared_ptr<DownstreamState> ds;
-    ComboAddress dest;
-    dest.reset();
-    dest.sin4.sin_family = ci.remote.sin4.sin_family;
-    socklen_t len = dest.getSocklen();
-    size_t queriesCount = 0;
-    time_t connectionStartTime = time(NULL);
-    std::vector<char> queryBuffer;
-    std::vector<char> answerBuffer;
+  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
+      return boost::none;
+    }
 
-    if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
-      dest = ci.cs->local;
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+        return now;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
+        now.tv_sec += remaining;
+        return now;
+      }
     }
 
-    try {
-      TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime);
+    now.tv_sec += g_tcpRecvTimeout;
+    return now;
+  }
 
-      for(;;) {
-        unsigned int remainingTime = 0;
-        ds = nullptr;
-        outstanding = false;
+  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() without any backend selected");
+    }
+    if (d_ds->tcpRecvTimeout == 0) {
+      return boost::none;
+    }
 
-        if(!getNonBlockingMsgLenFromClient(handler, &qlen)) {
-          break;
-        }
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpRecvTimeout;
 
-        queriesCount++;
+    return res;
+  }
 
-        if (qlen < sizeof(dnsheader)) {
-          ++g_stats.nonCompliantQueries;
-          break;
-        }
+  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+      return boost::none;
+    }
 
-        ci.cs->queries++;
-        ++g_stats.queries;
+    struct timeval res = now;
 
-        if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
-          break;
-        }
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+        return res;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+        res.tv_sec += remaining;
+        return res;
+      }
+    }
 
-        if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
-          break;
-        }
+    res.tv_sec += g_tcpSendTimeout;
+    return res;
+  }
 
-        bool ednsAdded = false;
-        bool ecsAdded = false;
-        /* allocate a bit more memory to be able to spoof the content,
-           or to add ECS without allocating a new buffer */
-        queryBuffer.resize(qlen + 512);
-
-        char* query = &queryBuffer[0];
-        handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
-
-        /* we need this one to be accurate ("real") for the protobuf message */
-       struct timespec queryRealTime;
-       struct timespec now;
-       gettime(&now);
-       gettime(&queryRealTime, true);
-
-#ifdef HAVE_DNSCRYPT
-        std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-        if (ci.cs->dnscryptCtx) {
-          dnsCryptQuery = std::make_shared<DNSCryptQuery>(ci.cs->dnscryptCtx);
-          uint16_t decryptedQueryLen = 0;
-          vector<uint8_t> response;
-          bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response);
-
-          if (!decrypted) {
-            if (response.size() > 0) {
-              handler.writeSizeAndMsg(response.data(), response.size(), g_tcpSendTimeout);
-            }
-            break;
-          }
-          qlen = decryptedQueryLen;
-        }
-#endif
-        struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() called without any backend selected");
+    }
+    if (d_ds->tcpSendTimeout == 0) {
+      return boost::none;
+    }
 
-        if (!checkQueryHeaders(dh)) {
-          goto drop;
-        }
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpSendTimeout;
+
+    return res;
+  }
 
-       string poolname;
-       int delayMsec=0;
+  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
+  {
+    if (maxConnectionDuration) {
+      time_t curtime = now.tv_sec;
+      unsigned int elapsed = 0;
+      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
+        elapsed = curtime - d_connectionStartTime.tv_sec;
+      }
+      if (elapsed >= maxConnectionDuration) {
+        return true;
+      }
+      d_remainingTime = maxConnectionDuration - elapsed;
+    }
 
-       const uint16_t* flags = getFlagsFromDNSHeader(dh);
-       uint16_t origFlags = *flags;
-       uint16_t qtype, qclass;
-       unsigned int consumed = 0;
-       DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-       DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+    return false;
+  }
 
-       if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
-         goto drop;
-       }
+  void dump() const
+  {
+    static std::mutex s_mutex;
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+
+    {
+      std::lock_guard<std::mutex> lock(s_mutex);
+      fprintf(stderr, "State is %p\n", this);
+      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
+      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
+      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
+      if (d_state > State::doingHandshake) {
+        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuery) {
+        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
+      }
+      if (d_state > State::sendingQueryToBackend) {
+        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
+      }
+      if (d_state > State::readingResponseFromBackend) {
+        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
+      }
+    }
+  }
 
-       if(dq.dh->qr) { // something turned it into a response
-          fixUpQueryTurnedResponse(dq, origFlags);
+  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
+
+  std::vector<uint8_t> d_buffer;
+  std::vector<uint8_t> d_responseBuffer;
+  TCPClientThreadData& d_threadData;
+  IDState d_ids;
+  ConnectionInfo d_ci;
+  TCPIOHandler d_handler;
+  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  struct timeval d_connectionStartTime;
+  struct timeval d_handshakeDoneTime;
+  struct timeval d_firstQuerySizeReadTime;
+  struct timeval d_querySizeReadTime;
+  struct timeval d_queryReadTime;
+  struct timeval d_querySentTime;
+  struct timeval d_responseReadTime;
+  size_t d_currentPos{0};
+  size_t d_queriesCount{0};
+  unsigned int d_remainingTime{0};
+  uint16_t d_querySize{0};
+  uint16_t d_responseSize{0};
+  uint16_t d_downstreamFailures{0};
+  State d_state{State::doingHandshake};
+  IOState d_lastIOState{IOState::Done};
+  bool d_readingFirstQuery{true};
+  bool d_outstanding{false};
+  bool d_firstResponsePacket{true};
+  bool d_isXFR{false};
+  bool d_xfrStarted{false};
+};
 
-          DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-          dr.uniqueId = dq.uniqueId;
-#endif
-          dr.qTag = dq.qTag;
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
 
-          if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-            goto drop;
-          }
+static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+
+  if (state->d_isXFR && state->d_downstreamConnection) {
+    /* we need to resume reading from the backend! */
+    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+    state->d_currentPos = 0;
+    handleDownstreamIO(state, now);
+    return;
+  }
 
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-            goto drop;
-          }
-#endif
-          handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
-          ++g_stats.selfAnswered;
-          continue;
-        }
+  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+    return;
+  }
 
-        std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-        std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    return;
+  }
 
-        auto policy = *(holders.policy);
-        if (serverPool->policy != nullptr) {
-          policy = *(serverPool->policy);
-        }
-        auto servers = serverPool->getServers();
-        if (policy.isLua) {
-          std::lock_guard<std::mutex> lock(g_luamutex);
-          ds = policy.policy(servers, &dq);
-        }
-        else {
-          ds = policy.policy(servers, &dq);
-        }
+  state->resetForNewQuery();
 
-        uint32_t cacheKeyNoECS = 0;
-        uint32_t cacheKey = 0;
-        boost::optional<Netmask> subnet;
-        char cachedResponse[4096];
-        uint16_t cachedResponseSize = sizeof cachedResponse;
-        uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
-        bool useZeroScope = false;
-
-        bool dnssecOK = false;
-        if (packetCache && !dq.skipCache) {
-          dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
-        }
+  handleIO(state, now);
+}
 
-        if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
-          // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-          if (packetCache && !dq.skipCache && (!ds || !ds->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-            if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-              DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-              dr.uniqueId = dq.uniqueId;
-#endif
-              dr.qTag = dq.qTag;
+static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
 
-              if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-                goto drop;
-              }
+  state->d_currentPos = 0;
 
-#ifdef HAVE_DNSCRYPT
-              if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-                goto drop;
-              }
-#endif
-              handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-              g_stats.cacheHits++;
-              continue;
-            }
-
-            if (!subnet) {
-              /* there was no existing ECS on the query, enable the zero-scope feature */
-              useZeroScope = true;
-            }
-          }
+  handleIO(state, now);
+}
 
-          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-            vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
-            goto drop;
-          }
-        }
+static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_responseSize < sizeof(dnsheader)) {
+    return;
+  }
 
-        if (packetCache && !dq.skipCache) {
-          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
+  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
+  unsigned int consumed;
+  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+    return;
+  }
+  state->d_firstResponsePacket = false;
 
-            if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
+  if (state->d_outstanding) {
+    --state->d_ds->outstanding;
+    state->d_outstanding = false;
+  }
 
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-            ++g_stats.cacheHits;
-            continue;
-          }
-          ++g_stats.cacheMisses;
-        }
+  auto dh = reinterpret_cast<struct dnsheader*>(response);
+  uint16_t addRoom = 0;
+  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+  if (dr.dnsCryptQuery) {
+    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+  }
 
-        if(!ds) {
-          ++g_stats.noPolicy;
+  dnsheader cleartextDH;
+  memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
 
-          if (g_servFailOnNoPolicy) {
-            restoreFlags(dh, origFlags);
-            dq.dh->rcode = RCode::ServFail;
-            dq.dh->qr = true;
+  std::vector<uint8_t> rewrittenResponse;
+  size_t responseSize = state->d_responseBuffer.size();
+  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+    return;
+  }
 
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
+  if (!rewrittenResponse.empty()) {
+    /* responseSize has been updated as well but we don't really care since it will match
+       the capacity of rewrittenResponse anyway */
+    state->d_responseBuffer = std::move(rewrittenResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+  } else {
+    /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */
+    state->d_responseBuffer.resize(state->d_responseSize);
+  }
 
-            if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
+  if (state->d_isXFR && !state->d_xfrStarted) {
+    /* don't bother parsing the content of the response for now */
+    state->d_xfrStarted = true;
+  }
 
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
+  sendResponse(state, now);
 
-            // no response-only statistics counter to update.
-            continue;
-          }
+  ++g_stats.responses;
+  struct timespec answertime;
+  gettime(&answertime);
+  double udiff = state->d_ids.sentTime.udiff();
+  g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
+}
 
-          break;
-        }
+static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  auto ds = state->d_ds;
+  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
+  state->d_currentPos = 0;
+  state->d_firstResponsePacket = true;
+  state->d_downstreamConnection.reset();
+
+  if (state->d_xfrStarted) {
+    /* sorry, but we are not going to resume a XFR if we have already sent some packets
+       to the client */
+    return;
+  }
 
-        if (dq.addXPF && ds->xpfRRCode != 0) {
-          addXPF(dq, ds->xpfRRCode, g_preserveTrailingData);
-        }
+  while (state->d_downstreamFailures < state->d_ds->retries)
+  {
+    state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
 
-       int dsock = -1;
-       uint16_t downstreamFailures=0;
-#ifdef MSG_FASTOPEN
-       bool freshConn = true;
-#endif /* MSG_FASTOPEN */
-       if(sockets.count(ds->remote) == 0) {
-         dsock=setupTCPDownstream(ds, downstreamFailures);
-         sockets[ds->remote]=dsock;
-       }
-       else {
-         dsock=sockets[ds->remote];
-#ifdef MSG_FASTOPEN
-         freshConn = false;
-#endif /* MSG_FASTOPEN */
-        }
+    if (!state->d_downstreamConnection) {
+      ++ds->tcpGaveUp;
+      ++state->d_ci.cs->tcpGaveUp;
+      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+      return;
+    }
 
-        ds->queries++;
-        ds->outstanding++;
-        outstanding = true;
+    handleDownstreamIO(state, now);
+    return;
+  }
 
-      retry:; 
-        if (dsock < 0) {
-          sockets.erase(ds->remote);
-          break;
-        }
+  ++ds->tcpGaveUp;
+  ++state->d_ci.cs->tcpGaveUp;
+  vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+}
 
-        if (ds->retries > 0 && downstreamFailures > ds->retries) {
-          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures);
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          break;
-        }
+static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_querySize < sizeof(dnsheader)) {
+    ++g_stats.nonCompliantQueries;
+    return;
+  }
 
-        try {
-          int socketFlags = 0;
-#ifdef MSG_FASTOPEN
-          if (ds->tcpFastOpen && freshConn) {
-            socketFlags |= MSG_FASTOPEN;
-          }
-#endif /* MSG_FASTOPEN */
-          sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags);
-        }
-        catch(const runtime_error& e) {
-          vinfolog("Downstream connection to %s died on us (%s), getting a new one!", ds->getName(), e.what());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
-#ifdef MSG_FASTOPEN
-          freshConn=true;
-#endif /* MSG_FASTOPEN */
-          goto retry;
-        }
+  state->d_readingFirstQuery = false;
+  ++state->d_queriesCount;
+  ++state->d_ci.cs->queries;
+  ++g_stats.queries;
+
+  /* we need an accurate ("real") value for the response and
+     to store into the IDS, but not for insertion into the
+     rings for example */
+  struct timespec queryRealTime;
+  gettime(&queryRealTime, true);
+
+  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
+  if (dnsCryptResponse) {
+    state->d_responseBuffer = std::move(*dnsCryptResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state, now);
+    return;
+  }
 
-        bool xfrStarted = false;
-        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
-        if (isXFR) {
-          dq.skipCache = true;
-        }
-        bool firstPacket=true;
-      getpacket:;
-
-        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
-         vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
+  const auto& dh = reinterpret_cast<dnsheader*>(query);
+  if (!checkQueryHeaders(dh)) {
+    return;
+  }
+
+  uint16_t qtype, qclass;
+  unsigned int consumed = 0;
+  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+  dq.dnsCryptQuery = std::move(dnsCryptQuery);
+
+  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
+  if (state->d_isXFR) {
+    dq.skipCache = true;
+  }
+
+  state->d_ds.reset();
+  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
+
+  if (result == ProcessQueryResult::Drop) {
+    return;
+  }
+
+  if (result == ProcessQueryResult::SendAnswer) {
+    state->d_buffer.resize(dq.len);
+    state->d_responseBuffer = std::move(state->d_buffer);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state, now);
+    return;
+  }
+
+  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
+    return;
+  }
+
+  state->d_buffer.resize(dq.len);
+  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
+
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
+  sendQueryToBackend(state, now);
+}
+
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
+{
+  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
+
+  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
+    state->d_threadData.mplexer->removeReadFD(fd);
+    //cerr<<__func__<<": remove read FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
+  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
+    state->d_threadData.mplexer->removeWriteFD(fd);
+    //cerr<<__func__<<": remove write FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
+
+  if (iostate == IOState::NeedRead) {
+    if (state->d_lastIOState == IOState::NeedRead) {
+      if (ttd) {
+        /* let's update the TTD ! */
+        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
+      }
+      return;
+    }
+
+    state->d_lastIOState = IOState::NeedRead;
+    //cerr<<__func__<<": add read FD "<<fd<<endl;
+    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::NeedWrite) {
+    if (state->d_lastIOState == IOState::NeedWrite) {
+      return;
+    }
+
+    state->d_lastIOState = IOState::NeedWrite;
+    //cerr<<__func__<<": add write FD "<<fd<<endl;
+    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::Done) {
+    state->d_lastIOState = IOState::Done;
+  }
+}
+
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_downstreamConnection == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+
+  int fd = state->d_downstreamConnection->getHandle();
+  IOState iostate = IOState::Done;
+  bool connectionDied = false;
+
+  try {
+    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+      int socketFlags = 0;
 #ifdef MSG_FASTOPEN
-          freshConn=true;
+      if (state->d_ds->tcpFastOpen && state->d_downstreamConnection->isFresh()) {
+        socketFlags |= MSG_FASTOPEN;
+      }
 #endif /* MSG_FASTOPEN */
-          if(xfrStarted) {
-            goto drop;
-          }
-          goto retry;
-        }
 
-        size_t responseSize = rlen;
-        uint16_t addRoom = 0;
-#ifdef HAVE_DNSCRYPT
-        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
-          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
-        }
-#endif
-        responseSize += addRoom;
-        answerBuffer.resize(responseSize);
-        char* response = answerBuffer.data();
-        readn2WithTimeout(dsock, response, rlen, ds->tcpRecvTimeout);
-        uint16_t responseLen = rlen;
-        if (outstanding) {
-          /* might be false for {A,I}XFR */
-          --ds->outstanding;
-          outstanding = false;
+      size_t sent = sendMsgWithTimeout(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, 0, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, 0, socketFlags);
+      if (sent == state->d_buffer.size()) {
+        /* request sent ! */
+        state->d_downstreamConnection->incQueries();
+        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+        state->d_currentPos = 0;
+        state->d_querySentTime = now;
+        iostate = IOState::NeedRead;
+        if (!state->d_isXFR) {
+          /* don't bother with the outstanding count for XFR queries */
+          ++state->d_ds->outstanding;
+          state->d_outstanding = true;
         }
+      }
+      else {
+        state->d_currentPos += sent;
+        iostate = IOState::NeedWrite;
+        /* disable fast open on partial write */
+        state->d_downstreamConnection->setReused();
+      }
+    }
 
-        if (rlen < sizeof(dnsheader)) {
-          break;
-        }
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
+      // then we need to allocate a new buffer (new because we might need to re-send the query if the
+      // backend dies on us
+      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+      // should very likely be a TCPIOHandler d_downstreamHandler
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
+        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
+        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
+        state->d_currentPos = 0;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
 
-        consumed = 0;
-        if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) {
-          break;
+        if (state->d_isXFR) {
+          /* Don't reuse the TCP connection after an {A,I}XFR */
+          /* but don't reset it either, we will need to read more messages */
         }
-        firstPacket=false;
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom, useZeroScope ? &zeroScope : nullptr)) {
-          break;
+        else {
+          releaseDownstreamConnection(std::move(state->d_downstreamConnection));
         }
+        fd = -1;
 
-        dh = (struct dnsheader*) response;
-        DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = dq.uniqueId;
-#endif
-        dr.qTag = dq.qTag;
+        state->d_responseReadTime = now;
+        handleResponse(state, now);
+        return;
+      }
+    }
 
-        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
-          break;
-        }
+    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
+    }
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
+    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+      ++state->d_ds->tcpDiedSendingQuery;
+    }
+    else {
+      ++state->d_ds->tcpDiedReadingResponse;
+    }
 
-       if (packetCache && !dq.skipCache) {
-          if (!useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          packetCache->insert(zeroScope ? cacheKeyNoECS : cacheKey, zeroScope ? boost::none : subnet, origFlags, dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
-       }
+    /* don't increase this counter when reusing connections */
+    if (state->d_downstreamConnection->isFresh()) {
+      ++state->d_downstreamFailures;
+    }
+    if (state->d_outstanding && state->d_ds != nullptr) {
+      --state->d_ds->outstanding;
+      state->d_outstanding = false;
+    }
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+    connectionDied = true;
+  }
+
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
+  }
+
+  if (connectionDied) {
+    sendQueryToBackend(state, now);
+  }
+}
+
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (state->d_downstreamConnection == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+  if (fd != state->d_downstreamConnection->getHandle()) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
+  }
 
-#ifdef HAVE_DNSCRYPT
-        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery, &dh, &dhCopy)) {
-          goto drop;
+  struct timeval now;
+  gettimeofday(&now, 0);
+  handleDownstreamIO(state, now);
+}
+
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  int fd = state->d_ci.fd;
+  IOState iostate = IOState::Done;
+
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+    return;
+  }
+
+  try {
+    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+      iostate = state->d_handler.tryHandshake();
+      if (iostate == IOState::Done) {
+        state->d_handshakeDoneTime = now;
+        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingQuery;
+        state->d_querySizeReadTime = now;
+        if (state->d_queriesCount == 0) {
+          state->d_firstQuerySizeReadTime = now;
         }
-#endif
-        if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
-          break;
+        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+        if (state->d_querySize < sizeof(dnsheader)) {
+          /* go away */
+          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+          return;
         }
 
-        if (isXFR) {
-          if (dh->rcode == 0 && dh->ancount != 0) {
-            if (xfrStarted == false) {
-              xfrStarted = true;
-              if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
-                goto getpacket;
-              }
-            }
-            else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
-              goto getpacket;
-            }
-          }
-          /* Don't reuse the TCP connection after an {A,I}XFR */
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-        }
+        /* allocate a bit more memory to be able to spoof the content,
+           or to add ECS without allocating a new buffer */
+        state->d_buffer.resize(state->d_querySize + 512);
+        state->d_currentPos = 0;
+      }
+    }
 
-        ++g_stats.responses;
-        struct timespec answertime;
-        gettime(&answertime);
-        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
-        g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote);
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+      if (iostate == IOState::Done) {
+        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+        handleQuery(state, now);
+        return;
+      }
+    }
 
-        rewrittenResponse.clear();
+    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
+      if (iostate == IOState::Done) {
+        handleResponseSent(state, now);
+        return;
       }
     }
-    catch(...) {}
 
-  drop:;
+    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
+    }
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
+        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      ++state->d_ci.cs->tcpDiedReadingQuery;
+    }
+    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      ++state->d_ci.cs->tcpDiedSendingResponse;
+    }
+
+    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+    }
+    else {
+      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+    }
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+  }
+
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+  }
+}
+
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (fd != state->d_ci.fd) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
+  }
+  struct timeval now;
+  gettimeofday(&now, 0);
+
+  handleIO(state, now);
+}
+
+static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+  auto threadData = boost::any_cast<TCPClientThreadData*>(param);
+
+  ConnectionInfo* citmp{nullptr};
+
+  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
+  if (got == 0) {
+    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+  else if (got == -1) {
+    if (errno == EAGAIN || errno == EINTR) {
+      return;
+    }
+    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + strerror(errno));
+  }
+  else if (got != sizeof(citmp)) {
+    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+
+  try {
+    g_tcpclientthreads->decrementQueuedCount();
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
+    delete citmp;
+    citmp = nullptr;
+
+    /* let's update the remaining time */
+    state->d_remainingTime = g_maxTCPConnectionDuration;
+
+    handleIO(state, now);
+  }
+  catch(...) {
+    delete citmp;
+    citmp = nullptr;
+    throw;
+  }
+}
 
-    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
+void tcpClientThread(int pipefd)
+{
+  /* we get launched with a pipe on which we receive file descriptors from clients that we own
+     from that point on */
+
+  setThreadName("dnsdist/tcpClie");
 
-    if (ds && outstanding) {
-      outstanding = false;
-      --ds->outstanding;
+  TCPClientThreadData data;
+
+  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
+  struct timeval now;
+  gettimeofday(&now, 0);
+  time_t lastTCPCleanup = now.tv_sec;
+  time_t lastTimeoutScan = now.tv_sec;
+
+  for (;;) {
+    data.mplexer->run(&now);
+
+    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
+      cleanupClosedTCPConnections();
+      lastTCPCleanup = now.tv_sec;
     }
-    decrementTCPClientCount(ci.remote);
 
-    if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
-      cleanupClosedTCPConnections(sockets);
-      lastTCPCleanup = time(nullptr);
+    if (now.tv_sec > lastTimeoutScan) {
+      lastTimeoutScan = now.tv_sec;
+      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
+      for(const auto& conn : expiredReadConns) {
+        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+        if (conn.first == state->d_ci.fd) {
+          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+          ++state->d_ci.cs->tcpClientTimeouts;
+        }
+        else if (state->d_ds) {
+          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
+          ++state->d_ci.cs->tcpDownstreamTimeouts;
+          ++state->d_ds->tcpReadTimeouts;
+        }
+        data.mplexer->removeReadFD(conn.first);
+        state->d_lastIOState = IOState::Done;
+      }
+
+      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
+      for(const auto& conn : expiredWriteConns) {
+        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+        if (conn.first == state->d_ci.fd) {
+          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+          ++state->d_ci.cs->tcpClientTimeouts;
+        }
+        else if (state->d_ds) {
+          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
+          ++state->d_ci.cs->tcpDownstreamTimeouts;
+          ++state->d_ds->tcpWriteTimeouts;
+        }
+        data.mplexer->removeWriteFD(conn.first);
+        state->d_lastIOState = IOState::Done;
+      }
     }
   }
 }
 
-/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and 
+/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
    they will hand off to worker threads & spawn more of them if required
 */
 void tcpAcceptorThread(void* p)
@@ -748,7 +1237,7 @@ void tcpAcceptorThread(void* p)
   bool tcpClientCountIncremented = false;
   ComboAddress remote;
   remote.sin4.sin_family = cs->local.sin4.sin_family;
-  
+
   g_tcpclientthreads->addTCPClientThread();
 
   auto acl = g_ACL.getLocal();
@@ -758,13 +1247,14 @@ void tcpAcceptorThread(void* p)
     tcpClientCountIncremented = false;
     try {
       socklen_t remlen = remote.getSocklen();
-      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo);
-      ci->cs = cs;
+      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
 #ifdef HAVE_ACCEPT4
-      ci->fd = accept4(cs->tcpFD, (struct sockaddr*)&remote, &remlen, SOCK_NONBLOCK);
+      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
 #else
-      ci->fd = accept(cs->tcpFD, (struct sockaddr*)&remote, &remlen);
+      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
 #endif
+      ++cs->tcpCurrentConnections;
+
       if(ci->fd < 0) {
         throw std::runtime_error((boost::format("accepting new connection on socket: %s") % strerror(errno)).str());
       }
@@ -821,7 +1311,7 @@ void tcpAcceptorThread(void* p)
         }
       }
     }
-    catch(std::exception& e) {
+    catch(const std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
       if(tcpClientCountIncremented) {
         decrementTCPClientCount(remote);
index 9cb51a622b17d5f1b37fb251402a9737284746f7..704c575172a0533690e185b4ea777bc56f7c635d 100644 (file)
@@ -446,20 +446,36 @@ static void connectionThread(int sock, ComboAddress remote)
         auto states = g_dstates.getLocal();
         const string statesbase = "dnsdist_server_";
 
-        output << "# HELP " << statesbase << "queries "     << "Amount of queries relayed to server"                               << "\n";
-        output << "# TYPE " << statesbase << "queries "     << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "drops "       << "Amount of queries not answered by server"                          << "\n";
-        output << "# TYPE " << statesbase << "drops "       << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "latency "     << "Server's latency when answering questions in miliseconds"          << "\n";
-        output << "# TYPE " << statesbase << "latency "     << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "senderrors "  << "Total number of OS snd errors while relaying queries"              << "\n";
-        output << "# TYPE " << statesbase << "senderrors "  << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "outstanding " << "Current number of queries that are waiting for a backend response" << "\n";
-        output << "# TYPE " << statesbase << "outstanding " << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "order "       << "The order in which this server is picked"                          << "\n";
-        output << "# TYPE " << statesbase << "order "       << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "weight "      << "The weight within the order in which this server is picked"        << "\n";
-        output << "# TYPE " << statesbase << "weight "      << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "queries "                << "Amount of queries relayed to server"                               << "\n";
+        output << "# TYPE " << statesbase << "queries "                << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "drops "                  << "Amount of queries not answered by server"                          << "\n";
+        output << "# TYPE " << statesbase << "drops "                  << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "latency "                << "Server's latency when answering questions in miliseconds"          << "\n";
+        output << "# TYPE " << statesbase << "latency "                << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "senderrors "             << "Total number of OS snd errors while relaying queries"              << "\n";
+        output << "# TYPE " << statesbase << "senderrors "             << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "outstanding "            << "Current number of queries that are waiting for a backend response" << "\n";
+        output << "# TYPE " << statesbase << "outstanding "            << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "order "                  << "The order in which this server is picked"                          << "\n";
+        output << "# TYPE " << statesbase << "order "                  << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "weight "                 << "The weight within the order in which this server is picked"        << "\n";
+        output << "# TYPE " << statesbase << "weight "                 << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpdiedsendingquery "    << "The number of TCP I/O errors while sending the query"              << "\n";
+        output << "# TYPE " << statesbase << "tcpdiedsendingquery "    << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpdiedreadingresponse " << "The number of TCP I/O errors while reading the response"           << "\n";
+        output << "# TYPE " << statesbase << "tcpdiedreadingresponse " << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpgaveup "              << "The number of TCP connections failing after too many attempts"     << "\n";
+        output << "# TYPE " << statesbase << "tcpgaveup "              << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpreadtimeouts "        << "The number of TCP read timeouts"                                   << "\n";
+        output << "# TYPE " << statesbase << "tcpreadtimeouts "        << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpwritetimeouts "       << "The number of TCP write timeouts"                                  << "\n";
+        output << "# TYPE " << statesbase << "tcpwritetimeouts "       << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpcurrentconnections "  << "The number of current TCP connections"                             << "\n";
+        output << "# TYPE " << statesbase << "tcpcurrentconnections "  << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpavgqueriesperconn "   << "The average number of queries per TCP connection"                  << "\n";
+        output << "# TYPE " << statesbase << "tcpavgqueriesperconn "   << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpavgconnduration "     << "The average duration of a TCP connection (ms)"                     << "\n";
+        output << "# TYPE " << statesbase << "tcpavgconnduration "     << "gauge"                                                             << "\n";
 
         for (const auto& state : *states) {
           string serverName;
@@ -474,13 +490,21 @@ static void connectionThread(int sock, ComboAddress remote)
           const std::string label = boost::str(boost::format("{server=\"%1%\",address=\"%2%\"}")
             % serverName % state->remote.toStringWithPort());
 
-          output << statesbase << "queries"     << label << " " << state->queries.load()     << "\n";
-          output << statesbase << "drops"       << label << " " << state->reuseds.load()     << "\n";
-          output << statesbase << "latency"     << label << " " << state->latencyUsec/1000.0 << "\n";
-          output << statesbase << "senderrors"  << label << " " << state->sendErrors.load()  << "\n";
-          output << statesbase << "outstanding" << label << " " << state->outstanding.load() << "\n";
-          output << statesbase << "order"       << label << " " << state->order              << "\n";
-          output << statesbase << "weight"      << label << " " << state->weight             << "\n";
+          output << statesbase << "queries"                << label << " " << state->queries.load()             << "\n";
+          output << statesbase << "drops"                  << label << " " << state->reuseds.load()             << "\n";
+          output << statesbase << "latency"                << label << " " << state->latencyUsec/1000.0         << "\n";
+          output << statesbase << "senderrors"             << label << " " << state->sendErrors.load()          << "\n";
+          output << statesbase << "outstanding"            << label << " " << state->outstanding.load()         << "\n";
+          output << statesbase << "order"                  << label << " " << state->order                      << "\n";
+          output << statesbase << "weight"                 << label << " " << state->weight                     << "\n";
+          output << statesbase << "tcpdiedsendingquery"    << label << " " << state->tcpDiedSendingQuery        << "\n";
+          output << statesbase << "tcpdiedreadingresponse" << label << " " << state->tcpDiedReadingResponse     << "\n";
+          output << statesbase << "tcpgaveup"              << label << " " << state->tcpGaveUp                  << "\n";
+          output << statesbase << "tcpreadtimeouts"        << label << " " << state->tcpReadTimeouts            << "\n";
+          output << statesbase << "tcpwritetimeouts"       << label << " " << state->tcpWriteTimeouts           << "\n";
+          output << statesbase << "tcpcurrentconnections"  << label << " " << state->tcpCurrentConnections      << "\n";
+          output << statesbase << "tcpavgqueriesperconn"   << label << " " << state->tcpAvgQueriesPerConnection << "\n";
+          output << statesbase << "tcpavgconnduration"     << label << " " << state->tcpAvgConnectionDuration   << "\n";
         }
 
         for (const auto& front : g_frontends) {
@@ -562,6 +586,14 @@ static void connectionThread(int sock, ComboAddress remote)
           {"latency", (double)(a->latencyUsec/1000.0)},
           {"queries", (double)a->queries},
           {"sendErrors", (double)a->sendErrors},
+          {"tcpDiedSendingQuery", (double)a->tcpDiedSendingQuery},
+          {"tcpDiedReadingResponse", (double)a->tcpDiedReadingResponse},
+          {"tcpGaveUp", (double)a->tcpGaveUp},
+          {"tcpReadTimeouts", (double)a->tcpReadTimeouts},
+          {"tcpWriteTimeouts", (double)a->tcpWriteTimeouts},
+          {"tcpCurrentConnections", (double)a->tcpCurrentConnections},
+          {"tcpAvgQueriesPerConnection", (double)a->tcpAvgQueriesPerConnection},
+          {"tcpAvgConnectionDuration", (double)a->tcpAvgConnectionDuration},
           {"dropRate", (double)a->dropRate}
         };
 
@@ -583,7 +615,16 @@ static void connectionThread(int sock, ComboAddress remote)
           { "address", front->local.toStringWithPort() },
           { "udp", front->udpFD >= 0 },
           { "tcp", front->tcpFD >= 0 },
-          { "queries", (double) front->queries.load() }
+          { "type", front->getType() },
+          { "queries", (double) front->queries.load() },
+          { "tcpDiedReadingQuery", (double) front->tcpDiedReadingQuery.load() },
+          { "tcpDiedSendingResponse", (double) front->tcpDiedSendingResponse.load() },
+          { "tcpGaveUp", (double) front->tcpGaveUp.load() },
+          { "tcpClientTimeouts", (double) front->tcpClientTimeouts },
+          { "tcpDownstreamTimeouts", (double) front->tcpDownstreamTimeouts },
+          { "tcpCurrentConnections", (double) front->tcpCurrentConnections },
+          { "tcpAvgQueriesPerConnection", (double) front->tcpAvgQueriesPerConnection },
+          { "tcpAvgConnectionDuration", (double) front->tcpAvgConnectionDuration },
         };
         frontends.push_back(frontend);
       }
@@ -640,9 +681,14 @@ static void connectionThread(int sock, ComboAddress remote)
         acl+=s;
       }
       string localaddresses;
-      for(const auto& loc : g_locals) {
-        if(!localaddresses.empty()) localaddresses += ", ";
-        localaddresses += std::get<0>(loc).toStringWithPort();
+      for(const auto& front : g_frontends) {
+        if (front->tcp) {
+          continue;
+        }
+        if (!localaddresses.empty()) {
+          localaddresses += ", ";
+        }
+        localaddresses += front->local.toStringWithPort();
       }
 
       Json my_json = Json::object {
index f2dfa39b218c0f3700345d9c20a9ece48c06d811..f7bfc1dee545f193182f8a758dbdfe520783a99c 100644 (file)
@@ -92,16 +92,13 @@ bool g_allowEmptyResponse{false};
 GlobalStateHolder<NetmaskGroup> g_ACL;
 string g_outputBuffer;
 
-vector<std::tuple<ComboAddress, bool, bool, int, string, std::set<int>>> g_locals;
 std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
-#ifdef HAVE_DNSCRYPT
-std::vector<std::tuple<ComboAddress,std::shared_ptr<DNSCryptContext>,bool, int, string, std::set<int> >> g_dnsCryptLocals;
-#endif
+std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;
 #ifdef HAVE_EBPF
 shared_ptr<BPFFilter> g_defaultBPFFilter;
 std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;
 #endif /* HAVE_EBPF */
-vector<ClientState *> g_frontends;
+std::vector<std::unique_ptr<ClientState>> g_frontends;
 GlobalStateHolder<pools_t> g_pools;
 size_t g_udpVectorSize{1};
 
@@ -144,6 +141,7 @@ bool g_servFailOnNoPolicy{false};
 bool g_truncateTC{false};
 bool g_fixupCase{false};
 bool g_preserveTrailingData{false};
+bool g_roundrobinFailOnNoServer{false};
 
 static void truncateTC(char* packet, uint16_t* len, size_t responseSize, unsigned int consumed)
 try
@@ -156,7 +154,7 @@ try
     hadEDNS = getEDNSUDPPayloadSizeAndZ(packet, *len, &payloadSize, &z);
   }
 
-  *len=(uint16_t) (sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
+  *len=static_cast<uint16_t>(sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
   struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
   dh->ancount = dh->arcount = dh->nscount = 0;
 
@@ -191,7 +189,7 @@ struct DelayedPacket
   }
 };
 
-DelayPipe<DelayedPacket> * g_delay = 0;
+DelayPipe<DelayedPacket>* g_delay = nullptr;
 
 void doLatencyStats(double udiff)
 {
@@ -214,14 +212,11 @@ void doLatencyStats(double udiff)
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed)
 {
-  uint16_t rqtype, rqclass;
-  DNSName rqname;
-  const struct dnsheader* dh = (struct dnsheader*) response;
-
   if (responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response);
   if (dh->qdcount == 0) {
     if ((dh->rcode != RCode::NoError && dh->rcode != RCode::NXDomain) || g_allowEmptyResponse) {
       return true;
@@ -232,12 +227,15 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
     }
   }
 
+  uint16_t rqtype, rqclass;
+  DNSName rqname;
   try {
     rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
   }
-  catch(std::exception& e) {
-    if(responseLen > (ssize_t)sizeof(dnsheader))
+  catch(const std::exception& e) {
+    if(responseLen > 0 && static_cast<size_t>(responseLen) > sizeof(dnsheader)) {
       infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
+    }
     ++g_stats.nonCompliantResponses;
     return false;
   }
@@ -249,7 +247,7 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
   return true;
 }
 
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
+static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
 {
   static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
   static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
@@ -263,21 +261,20 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
 {
   restoreFlags(dq.dh, origFlags);
 
   return addEDNSToQueryTurnedResponse(dq);
 }
 
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
+static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
 {
-  struct dnsheader* dh = (struct dnsheader*) *response;
-
   if (*responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(*response);
   restoreFlags(dh, origFlags);
 
   if (*responseLen == sizeof(dnsheader)) {
@@ -368,7 +365,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
 }
 
 #ifdef HAVE_DNSCRYPT
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
+static bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
 {
   if (dnsCryptQuery) {
     uint16_t encryptedResponseLen = 0;
@@ -390,9 +387,82 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize,
   }
   return true;
 }
-#endif
+#endif /* HAVE_DNSCRYPT */
 
-static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+static bool applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr)
+{
+  DNSResponseAction::Action action=DNSResponseAction::Action::None;
+  std::string ruleresult;
+  for(const auto& lr : *localRespRulactions) {
+    if(lr.d_rule->matches(&dr)) {
+      lr.d_rule->d_matches++;
+      action=(*lr.d_action)(&dr, &ruleresult);
+      switch(action) {
+      case DNSResponseAction::Action::Allow:
+        return true;
+        break;
+      case DNSResponseAction::Action::Drop:
+        return false;
+        break;
+      case DNSResponseAction::Action::HeaderModify:
+        return true;
+        break;
+      case DNSResponseAction::Action::ServFail:
+        dr.dh->rcode = RCode::ServFail;
+        return true;
+        break;
+        /* non-terminal actions follow */
+      case DNSResponseAction::Action::Delay:
+        dr.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        break;
+      case DNSResponseAction::Action::None:
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted)
+{
+  if (!applyRulesToResponse(localRespRulactions, dr)) {
+    return false;
+  }
+
+  bool zeroScope = false;
+  if (!fixUpResponse(response, responseLen, responseSize, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, rewrittenResponse, addRoom, dr.useZeroScope ? &zeroScope : nullptr)) {
+    return false;
+  }
+
+  if (dr.packetCache && !dr.skipCache) {
+    if (!dr.useZeroScope) {
+      /* if the query was not suitable for zero-scope, for
+         example because it had an existing ECS entry so the hash is
+         not really 'no ECS', so just insert it for the existing subnet
+         since:
+         - we don't have the correct hash for a non-ECS query
+         - inserting with hash computed before the ECS replacement but with
+         the subnet extracted _after_ the replacement would not work.
+      */
+      zeroScope = false;
+    }
+    // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
+    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, *response, *responseLen, dr.tcp, dr.dh->rcode, dr.tempFailureTTL);
+  }
+
+#ifdef HAVE_DNSCRYPT
+  if (!muted) {
+    if (!encryptResponse(*response, responseLen, *responseSize, dr.tcp, dr.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
+    }
+  }
+#endif /* HAVE_DNSCRYPT */
+
+  return true;
+}
+
+static bool sendUDPResponse(int origFD, const char* response, const uint16_t responseLen, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   if(delayMsec && g_delay) {
     DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
@@ -401,7 +471,7 @@ static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, in
   else {
     ssize_t res;
     if(origDest.sin4.sin_family == 0) {
-      res = sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen());
+      res = sendto(origFD, response, responseLen, 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
     }
     else {
       res = sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
@@ -416,7 +486,7 @@ static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, in
 }
 
 
-static int pickBackendSocketForSending(DownstreamState* state)
+static int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
 {
   return state->sockets[state->socketsOffset++ % state->sockets.size()];
 }
@@ -441,15 +511,11 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 try {
   setThreadName("dnsdist/respond");
   auto localRespRulactions = g_resprulactions.getLocal();
-#ifdef HAVE_DNSCRYPT
   char packet[4096 + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE];
+  static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
   /* when the answer is encrypted in place, we need to get a copy
      of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
-#else
-  char packet[4096];
-#endif
-  static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
+  dnsheader cleartextDH;
   vector<uint8_t> rewrittenResponse;
 
   uint16_t queryId = 0;
@@ -465,10 +531,10 @@ try {
         char * response = packet;
         size_t responseSize = sizeof(packet);
 
-        if (got < (ssize_t) sizeof(dnsheader))
+        if (got < 0 || static_cast<size_t>(got) < sizeof(dnsheader))
           continue;
 
-        uint16_t responseLen = (uint16_t) got;
+        uint16_t responseLen = static_cast<uint16_t>(got);
         queryId = dh->id;
 
         if(queryId >= dss->idStates.size())
@@ -508,80 +574,49 @@ try {
         dh->id = ids->origID;
 
         uint16_t addRoom = 0;
-        DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = ids->uniqueId;
-#endif
-        dr.qTag = ids->qTag;
-
-        if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
-          continue;
-        }
-
-#ifdef HAVE_DNSCRYPT
-        if (ids->dnsCryptQuery) {
+        DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false);
+        if (dr.dnsCryptQuery) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
         }
-#endif
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, ids->ecsAdded, rewrittenResponse, addRoom, ids->useZeroScope ? &zeroScope : nullptr)) {
-          continue;
-        }
 
-        if (ids->packetCache && !ids->skipCache) {
-          if (!ids->useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          ids->packetCache->insert(zeroScope ? ids->cacheKeyNoECS : ids->cacheKey, zeroScope ? boost::none : ids->subnet, ids->origFlags, ids->dnssecOK, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
+        memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
+        if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, ids->cs && ids->cs->muted)) {
+          continue;
         }
 
         if (ids->cs && !ids->cs->muted) {
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery, &dh, &dhCopy)) {
-            continue;
-          }
-#endif
-
           ComboAddress empty;
           empty.sin4.sin_family = 0;
           /* if ids->destHarvested is false, origDest holds the listening address.
              We don't want to use that as a source since it could be 0.0.0.0 for example. */
-          sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
+          sendUDPResponse(origFD, response, responseLen, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
         }
 
         ++g_stats.responses;
 
         double udiff = ids->sentTime.udiff();
-        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
+        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), dr.remote->toStringWithPort(), udiff);
 
         struct timespec ts;
         gettime(&ts);
-        g_rings.insertResponse(ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, dss->remote);
+        g_rings.insertResponse(ts, *dr.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(got), cleartextDH, dss->remote);
 
-        if(dh->rcode == RCode::ServFail) {
+        switch (dh->rcode) {
+        case RCode::NXDomain:
+          ++g_stats.frontendNXDomain;
+          break;
+        case RCode::ServFail:
           ++g_stats.servfailResponses;
+          ++g_stats.frontendServFail;
+          break;
+        case RCode::NoError:
+          ++g_stats.frontendNoError;
+          break;
         }
         dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
 
         doLatencyStats(udiff);
 
-        /* if the FD is not -1, the state has been actively reused and we should
-           not alter anything */
-        if (ids->origFD == -1) {
-#ifdef HAVE_DNSCRYPT
-          ids->dnsCryptQuery = nullptr;
-#endif
-        }
-
         rewrittenResponse.clear();
       }
     }
@@ -849,7 +884,7 @@ shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, cons
   }
 
   const auto *res=&poss;
-  if(poss.empty())
+  if(poss.empty() && !g_roundrobinFailOnNoServer)
     res = &servers;
 
   if(res->empty())
@@ -962,16 +997,16 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
   }
 }
 
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
+static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, const struct timespec& now)
 {
-  g_rings.insertQuery(now,*dq.remote,*dq.qname,dq.qtype,dq.len,*dq.dh);
+  g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
 
   if(g_qcount.enabled) {
     string qname = (*dq.qname).toString(".");
     bool countQuery{true};
     if(g_qcount.filter) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      std::tie (countQuery, qname) = g_qcount.filter(dq);
+      std::tie (countQuery, qname) = g_qcount.filter(&dq);
     }
 
     if(countQuery) {
@@ -1146,7 +1181,7 @@ bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int*
         break;
         /* non-terminal actions follow */
       case DNSAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
         break;
       case DNSAction::Action::None:
         /* fall-through */
@@ -1163,42 +1198,7 @@ bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int*
   return true;
 }
 
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec)
-{
-  DNSResponseAction::Action action=DNSResponseAction::Action::None;
-  std::string ruleresult;
-  for(const auto& lr : *localRespRulactions) {
-    if(lr.d_rule->matches(&dr)) {
-      lr.d_rule->d_matches++;
-      action=(*lr.d_action)(&dr, &ruleresult);
-      switch(action) {
-      case DNSResponseAction::Action::Allow:
-        return true;
-        break;
-      case DNSResponseAction::Action::Drop:
-        return false;
-        break;
-      case DNSResponseAction::Action::HeaderModify:
-        return true;
-        break;
-      case DNSResponseAction::Action::ServFail:
-        dr.dh->rcode = RCode::ServFail;
-        return true;
-        break;
-        /* non-terminal actions follow */
-      case DNSResponseAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
-        break;
-      case DNSResponseAction::Action::None:
-        break;
-      }
-    }
-  }
-
-  return true;
-}
-
-static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
+static ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
 {
   ssize_t result;
 
@@ -1261,29 +1261,29 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
   return true;
 }
 
-#ifdef HAVE_DNSCRYPT
-static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now)
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
 {
   if (cs.dnscryptCtx) {
+#ifdef HAVE_DNSCRYPT
     vector<uint8_t> response;
     uint16_t decryptedQueryLen = 0;
 
     dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
 
-    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response);
+    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response);
 
     if (!decrypted) {
       if (response.size() > 0) {
-        sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(response.data()), static_cast<uint16_t>(response.size()), 0, dest, remote);
+        return response;
       }
-      return false;
+      throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
     }
 
     len = decryptedQueryLen;
+#endif /* HAVE_DNSCRYPT */
   }
-  return true;
+  return boost::none;
 }
-#endif /* HAVE_DNSCRYPT */
 
 bool checkQueryHeaders(const struct dnsheader* dh)
 {
@@ -1319,109 +1319,82 @@ static void queueResponse(const ClientState& cs, const char* response, uint16_t
 }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
 
-static int sendAndEncryptUDPResponse(LocalHolders& holders, ClientState& cs, const DNSQuestion& dq, char* response, uint16_t responseLen, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, int delayMsec, const ComboAddress& dest, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf, bool cacheHit)
+/* self-generated responses or cache hits */
+static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
 {
-  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, dq.queryTime);
+  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(dq.dh), dq.size, dq.len, dq.tcp, dq.queryTime);
+
 #ifdef HAVE_PROTOBUF
   dr.uniqueId = dq.uniqueId;
 #endif
   dr.qTag = dq.qTag;
+  dr.delayMsec = dq.delayMsec;
 
-  if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-    return -1;
+  if (!applyRulesToResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr)) {
+    return false;
   }
 
-  if (!cs.muted) {
+  /* in case a rule changed it */
+  dq.delayMsec = dr.delayMsec;
+
 #ifdef HAVE_DNSCRYPT
-    if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
-      return -1;
-    }
-#endif
-#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-    if (delayMsec == 0 && responsesVect != nullptr) {
-      queueResponse(cs, response, responseLen, dest, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
-      (*queuedResponses)++;
+  if (!cs.muted) {
+    if (!encryptResponse(reinterpret_cast<char*>(dq.dh), &dq.len, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
     }
-    else
-#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-      {
-        sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, *dq.remote);
-      }
   }
+#endif /* HAVE_DNSCRYPT */
 
   if (cacheHit) {
     ++g_stats.cacheHits;
   }
+
+  switch (dr.dh->rcode) {
+  case RCode::NXDomain:
+    ++g_stats.frontendNXDomain;
+    break;
+  case RCode::ServFail:
+    ++g_stats.frontendServFail;
+    break;
+  case RCode::NoError:
+    ++g_stats.frontendNoError;
+    break;
+  }
+
   doLatencyStats(0);  // we're not going to measure this
-  return 0;
+  return true;
 }
 
-static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
-  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
-  uint16_t queryId = 0;
+  const uint16_t queryId = ntohs(dq.dh->id);
 
   try {
-    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
-      return;
-    }
-
     /* we need an accurate ("real") value for the response and
        to store into the IDS, but not for insertion into the
        rings for example */
-    struct timespec queryRealTime;
     struct timespec now;
     gettime(&now);
-    gettime(&queryRealTime, true);
-
-    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-#ifdef HAVE_DNSCRYPT
-    if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) {
-      return;
-    }
-#endif
-
-    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
-    queryId = ntohs(dh->id);
-
-    if (!checkQueryHeaders(dh)) {
-      return;
-    }
 
     string poolname;
-    int delayMsec = 0;
-    const uint16_t * flags = getFlagsFromDNSHeader(dh);
-    const uint16_t origFlags = *flags;
-    uint16_t qtype, qclass;
-    unsigned int consumed = 0;
-    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
-    bool dnssecOK = false;
 
-    if (!processQuery(holders, dq, poolname, &delayMsec, now))
-    {
-      return;
+    if (!applyRulesToQuery(holders, dq, poolname, now)) {
+      return ProcessQueryResult::Drop;
     }
 
     if(dq.dh->qr) { // something turned it into a response
-      fixUpQueryTurnedResponse(dq, origFlags);
-
-      if (!cs.muted) {
-        char* response = query;
-        uint16_t responseLen = dq.len;
-
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
+      fixUpQueryTurnedResponse(dq, dq.origFlags);
 
-        ++g_stats.selfAnswered;
+      if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+        return ProcessQueryResult::Drop;
       }
 
-      return;
+      ++g_stats.selfAnswered;
+      return ProcessQueryResult::SendAnswer;
     }
 
-    DownstreamState* ss = nullptr;
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-    std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+    dq.packetCache = serverPool->packetCache;
     auto policy = *(holders.policy);
     if (serverPool->policy != nullptr) {
       policy = *(serverPool->policy);
@@ -1429,77 +1402,148 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     auto servers = serverPool->getServers();
     if (policy.isLua) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      ss = policy.policy(servers, &dq).get();
+      selectedBackend = policy.policy(servers, &dq);
     }
     else {
-      ss = policy.policy(servers, &dq).get();
+      selectedBackend = policy.policy(servers, &dq);
     }
 
-    bool ednsAdded = false;
-    bool ecsAdded = false;
-    uint32_t cacheKeyNoECS = 0;
-    uint32_t cacheKey = 0;
-    boost::optional<Netmask> subnet;
     uint16_t cachedResponseSize = dq.size;
-    uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
-    bool useZeroScope = false;
+    uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
 
-    if (packetCache && !dq.skipCache) {
-      dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
+    if (dq.packetCache && !dq.skipCache) {
+      dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
     }
 
-    if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
+    if (dq.useECS && ((selectedBackend && selectedBackend->useECS) || (!selectedBackend && serverPool->getECS()))) {
       // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-      if (packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-        if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-          sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-          return;
+      if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
+        if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
+          dq.len = cachedResponseSize;
+
+          if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+            return ProcessQueryResult::Drop;
+          }
+
+          return ProcessQueryResult::SendAnswer;
         }
 
-        if (!subnet) {
+        if (!dq.subnet) {
           /* there was no existing ECS on the query, enable the zero-scope feature */
-          useZeroScope = true;
+          dq.useZeroScope = true;
         }
       }
 
-      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-        vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
-        return;
+      if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
+        vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
+        return ProcessQueryResult::Drop;
       }
     }
 
-    if (packetCache && !dq.skipCache) {
-      if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-        sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-        return;
+    if (dq.packetCache && !dq.skipCache) {
+      if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
+        dq.len = cachedResponseSize;
+
+        if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+          return ProcessQueryResult::Drop;
+        }
+
+        return ProcessQueryResult::SendAnswer;
       }
       ++g_stats.cacheMisses;
     }
 
-    if(!ss) {
+    if(!selectedBackend) {
       ++g_stats.noPolicy;
 
-      if (g_servFailOnNoPolicy && !cs.muted) {
-        char* response = query;
-        uint16_t responseLen = dq.len;
-        restoreFlags(dh, origFlags);
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      if (g_servFailOnNoPolicy) {
+        restoreFlags(dq.dh, dq.origFlags);
 
         dq.dh->rcode = RCode::ServFail;
         dq.dh->qr = true;
 
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
-
+        if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+          return ProcessQueryResult::Drop;
+        }
         // no response-only statistics counter to update.
+        return ProcessQueryResult::SendAnswer;
       }
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
+
+      return ProcessQueryResult::Drop;
+    }
+
+    if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
+      addXPF(dq, selectedBackend->xpfRRCode, g_preserveTrailingData);
+    }
+
+    selectedBackend->queries++;
+    return ProcessQueryResult::PassToBackend;
+  }
+  catch(const std::exception& e){
+    vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
+  }
+  return ProcessQueryResult::Drop;
+}
+
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+{
+  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
+  uint16_t queryId = 0;
+
+  try {
+    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
       return;
     }
 
-    if (dq.addXPF && ss->xpfRRCode != 0) {
-      addXPF(dq, ss->xpfRRCode, g_preserveTrailingData);
+    /* we need an accurate ("real") value for the response and
+       to store into the IDS, but not for insertion into the
+       rings for example */
+    struct timespec queryRealTime;
+    gettime(&queryRealTime, true);
+
+    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
+    auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
+    if (dnsCryptResponse) {
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), 0, dest, remote);
+      return;
     }
 
-    ss->queries++;
+    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+    queryId = ntohs(dh->id);
+
+    if (!checkQueryHeaders(dh)) {
+      return;
+    }
+
+    uint16_t qtype, qclass;
+    unsigned int consumed = 0;
+    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
+    dq.dnsCryptQuery = std::move(dnsCryptQuery);
+    std::shared_ptr<DownstreamState> ss{nullptr};
+    auto result = processQuery(dq, cs, holders, ss);
+
+    if (result == ProcessQueryResult::Drop) {
+      return;
+    }
+
+    if (result == ProcessQueryResult::SendAnswer) {
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+      if (dq.delayMsec == 0 && responsesVect != nullptr) {
+        queueResponse(cs, reinterpret_cast<char*>(dq.dh), dq.len, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+        (*queuedResponses)++;
+        return;
+      }
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+      /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dq.dh), dq.len, dq.delayMsec, dest, *dq.remote);
+      return;
+    }
+
+    if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
+      return;
+    }
 
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
@@ -1517,24 +1561,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     ids->cs = &cs;
     ids->origID = dh->id;
-    ids->origRemote = remote;
-    ids->sentTime.set(queryRealTime);
-    ids->qname = qname;
-    ids->qtype = dq.qtype;
-    ids->qclass = dq.qclass;
-    ids->delayMsec = delayMsec;
-    ids->tempFailureTTL = dq.tempFailureTTL;
-    ids->origFlags = origFlags;
-    ids->cacheKey = cacheKey;
-    ids->cacheKeyNoECS = cacheKeyNoECS;
-    ids->subnet = subnet;
-    ids->skipCache = dq.skipCache;
-    ids->packetCache = packetCache;
-    ids->ednsAdded = ednsAdded;
-    ids->ecsAdded = ecsAdded;
-    ids->useZeroScope = useZeroScope;
-    ids->qTag = dq.qTag;
-    ids->dnssecOK = dnssecOK;
+    setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
     /* If we couldn't harvest the real dest addr, still
        write down the listening addr since it will be useful
@@ -1550,12 +1577,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids->origDest = cs.local;
       ids->destHarvested = false;
     }
-#ifdef HAVE_DNSCRYPT
-    ids->dnsCryptQuery = dnsCryptQuery;
-#endif
-#ifdef HAVE_PROTOBUF
-    ids->uniqueId = dq.uniqueId;
-#endif
 
     dh->id = idOffset;
 
@@ -1631,8 +1652,8 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
       unsigned int got = msgVec[msgIdx].msg_len;
       const ComboAddress& remote = recvData[msgIdx].remote;
 
-      if (got < sizeof(struct dnsheader)) {
-        g_stats.nonCompliantQueries++;
+      if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
+        ++g_stats.nonCompliantQueries;
         continue;
       }
 
@@ -1721,12 +1742,12 @@ uint16_t getRandomDNSID()
 #endif
 }
 
-static bool upCheck(DownstreamState& ds)
+static bool upCheck(const shared_ptr<DownstreamState>& ds)
 try
 {
-  DNSName checkName = ds.checkName;
-  uint16_t checkType = ds.checkType.getCode();
-  uint16_t checkClass = ds.checkClass;
+  DNSName checkName = ds->checkName;
+  uint16_t checkType = ds->checkType.getCode();
+  uint16_t checkClass = ds->checkClass;
   dnsheader checkHeader;
   memset(&checkHeader, 0, sizeof(checkHeader));
 
@@ -1734,13 +1755,13 @@ try
   checkHeader.id = getRandomDNSID();
 
   checkHeader.rd = true;
-  if (ds.setCD) {
+  if (ds->setCD) {
     checkHeader.cd = true;
   }
 
-  if (ds.checkFunction) {
+  if (ds->checkFunction) {
     std::lock_guard<std::mutex> lock(g_luamutex);
-    auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader);
+    auto ret = ds->checkFunction(checkName, checkType, checkClass, &checkHeader);
     checkName = std::get<0>(ret);
     checkType = std::get<1>(ret);
     checkClass = std::get<2>(ret);
@@ -1751,31 +1772,31 @@ try
   dnsheader * requestHeader = dpw.getHeader();
   *requestHeader = checkHeader;
 
-  Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM);
+  Socket sock(ds->remote.sin4.sin_family, SOCK_DGRAM);
   sock.setNonBlocking();
-  if (!IsAnyAddress(ds.sourceAddr)) {
+  if (!IsAnyAddress(ds->sourceAddr)) {
     sock.setReuseAddr();
-    sock.bind(ds.sourceAddr);
+    sock.bind(ds->sourceAddr);
   }
-  sock.connect(ds.remote);
-  ssize_t sent = udpClientSendRequestToBackend(&ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
+  sock.connect(ds->remote);
+  ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), reinterpret_cast<char*>(&packet[0]), packet.size(), true);
   if (sent < 0) {
     int ret = errno;
     if (g_verboseHealthChecks)
-      infolog("Error while sending a health check query to backend %s: %d", ds.getNameWithAddr(), ret);
+      infolog("Error while sending a health check query to backend %s: %d", ds->getNameWithAddr(), ret);
     return false;
   }
 
-  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds.checkTimeout / 1000, /* remaining ms to us */ (ds.checkTimeout % 1000) * 1000);
+  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds->checkTimeout / 1000, /* remaining ms to us */ (ds->checkTimeout % 1000) * 1000);
   if(ret < 0 || !ret) { // error, timeout, both are down!
     if (ret < 0) {
       ret = errno;
       if (g_verboseHealthChecks)
-        infolog("Error while waiting for the health check response from backend %s: %d", ds.getNameWithAddr(), ret);
+        infolog("Error while waiting for the health check response from backend %s: %d", ds->getNameWithAddr(), ret);
     }
     else {
       if (g_verboseHealthChecks)
-        infolog("Timeout while waiting for the health check response from backend %s", ds.getNameWithAddr());
+        infolog("Timeout while waiting for the health check response from backend %s", ds->getNameWithAddr());
     }
     return false;
   }
@@ -1785,9 +1806,9 @@ try
   sock.recvFrom(reply, from);
 
   /* we are using a connected socket but hey.. */
-  if (from != ds.remote) {
+  if (from != ds->remote) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds.remote.toStringWithPort());
+      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds->remote.toStringWithPort());
     return false;
   }
 
@@ -1795,31 +1816,31 @@ try
 
   if (reply.size() < sizeof(*responseHeader)) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds.getNameWithAddr(), sizeof(*responseHeader));
+      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds->getNameWithAddr(), sizeof(*responseHeader));
     return false;
   }
 
   if (responseHeader->id != requestHeader->id) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds.getNameWithAddr(), requestHeader->id);
+      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds->getNameWithAddr(), requestHeader->id);
     return false;
   }
 
   if (!responseHeader->qr) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response from backend %s, expecting QR to be set", ds.getNameWithAddr());
+      infolog("Invalid health check response from backend %s, expecting QR to be set", ds->getNameWithAddr());
     return false;
   }
 
   if (responseHeader->rcode == RCode::ServFail) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with ServFail", ds.getNameWithAddr());
+      infolog("Backend %s responded to health check with ServFail", ds->getNameWithAddr());
     return false;
   }
 
-  if (ds.mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
+  if (ds->mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with %s while mustResolve is set", ds.getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
+      infolog("Backend %s responded to health check with %s while mustResolve is set", ds->getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
     return false;
   }
 
@@ -1829,7 +1850,7 @@ try
 
   if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
+      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds->getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
     return false;
   }
 
@@ -1838,13 +1859,13 @@ try
 catch(const std::exception& e)
 {
   if (g_verboseHealthChecks)
-    infolog("Error checking the health of backend %s: %s", ds.getNameWithAddr(), e.what());
+    infolog("Error checking the health of backend %s: %s", ds->getNameWithAddr(), e.what());
   return false;
 }
 catch(...)
 {
   if (g_verboseHealthChecks)
-    infolog("Unknown exception while checking the health of backend %s", ds.getNameWithAddr());
+    infolog("Unknown exception while checking the health of backend %s", ds->getNameWithAddr());
   return false;
 }
 
@@ -1958,15 +1979,15 @@ static void healthChecksThread()
         continue;
       dss->lastCheck = 0;
       if(dss->availability==DownstreamState::Availability::Auto) {
-        bool newState=upCheck(*dss);
+        bool newState=upCheck(dss);
         if (newState) {
           /* check succeeded */
           dss->currentCheckFailures = 0;
 
           if (!dss->upStatus) {
             /* we were marked as down */
-            dss->consecutiveSuccesfulChecks++;
-            if (dss->consecutiveSuccesfulChecks < dss->minRiseSuccesses) {
+            dss->consecutiveSuccessfulChecks++;
+            if (dss->consecutiveSuccessfulChecks < dss->minRiseSuccesses) {
               /* if we need more than one successful check to rise
                  and we didn't reach the threshold yet,
                  let's stay down */
@@ -1976,7 +1997,7 @@ static void healthChecksThread()
         }
         else {
           /* check failed */
-          dss->consecutiveSuccesfulChecks = 0;
+          dss->consecutiveSuccessfulChecks = 0;
 
           if (dss->upStatus) {
             /* we are currently up */
@@ -2002,7 +2023,7 @@ static void healthChecksThread()
 
           dss->upStatus = newState;
           dss->currentCheckFailures = 0;
-          dss->consecutiveSuccesfulChecks = 0;
+          dss->consecutiveSuccessfulChecks = 0;
           if (g_snmpAgent && g_snmpTrapsEnabled) {
             g_snmpAgent->sendBackendStatusChangeTrap(dss);
           }
@@ -2147,6 +2168,102 @@ static void checkFileDescriptorsLimits(size_t udpBindsCount, size_t tcpBindsCoun
   }
 }
 
+static void setUpLocalBind(std::unique_ptr<ClientState>& cs)
+{
+  /* skip some warnings if there is an identical UDP context */
+  bool warn = cs->tcp == false || cs->tlsFrontend != nullptr;
+  int& fd = cs->tcp == false ? cs->udpFD : cs->tcpFD;
+  (void) warn;
+
+  fd = SSocket(cs->local.sin4.sin_family, cs->tcp == false ? SOCK_DGRAM : SOCK_STREAM, 0);
+
+  if (cs->tcp) {
+    SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
+#ifdef TCP_DEFER_ACCEPT
+    SSetsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
+#endif
+    if (cs->fastOpenQueueSize > 0) {
+#ifdef TCP_FASTOPEN
+      SSetsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN, cs->fastOpenQueueSize);
+#else
+      if (warn) {
+        warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
+      }
+#endif
+    }
+  }
+
+  if(cs->local.sin4.sin_family == AF_INET6) {
+    SSetsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, 1);
+  }
+
+  bindAny(cs->local.sin4.sin_family, fd);
+
+  if(!cs->tcp && IsAnyAddress(cs->local)) {
+    int one=1;
+    setsockopt(fd, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one));     // linux supports this, so why not - might fail on other systems
+#ifdef IPV6_RECVPKTINFO
+    setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
+#endif
+  }
+
+  if (cs->reuseport) {
+#ifdef SO_REUSEPORT
+    SSetsockopt(fd, SOL_SOCKET, SO_REUSEPORT, 1);
+#else
+    if (warn) {
+      /* no need to warn again if configured but support is not available, we already did for UDP */
+      warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
+    }
+#endif
+  }
+
+  const std::string& itf = cs->interface;
+  if (!itf.empty()) {
+#ifdef SO_BINDTODEVICE
+    int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
+    if (res != 0) {
+      warnlog("Error setting up the interface on local address '%s': %s", cs->local.toStringWithPort(), strerror(errno));
+    }
+#else
+    if (warn) {
+      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", cs->local.toStringWithPort());
+    }
+#endif
+  }
+
+#ifdef HAVE_EBPF
+  if (g_defaultBPFFilter) {
+    cs->attachFilter(g_defaultBPFFilter);
+    vinfolog("Attaching default BPF Filter to %s frontend %s", (!cs->tcp ? "UDP" : "TCP"), cs->local.toStringWithPort());
+  }
+#endif /* HAVE_EBPF */
+
+  if (cs->tlsFrontend != nullptr) {
+    if (!cs->tlsFrontend->setupTLS()) {
+      errlog("Error while setting up TLS on local address '%s', exiting", cs->local.toStringWithPort());
+      _exit(EXIT_FAILURE);
+    }
+  }
+
+  SBind(fd, cs->local);
+
+  if (cs->tcp) {
+    SListen(cs->tcpFD, 64);
+    if (cs->tlsFrontend != nullptr) {
+      warnlog("Listening on %s for TLS", cs->local.toStringWithPort());
+    }
+    else if (cs->dnscryptCtx != nullptr) {
+      warnlog("Listening on %s for DNSCrypt", cs->local.toStringWithPort());
+    }
+    else {
+      warnlog("Listening on %s", cs->local.toStringWithPort());
+    }
+  }
+
+  cs->ready = true;
+}
+
 struct 
 {
   vector<string> locals;
@@ -2206,7 +2323,7 @@ try
 
   signal(SIGPIPE, SIG_IGN);
   signal(SIGCHLD, SIG_IGN);
-  openlog("dnsdist", LOG_PID, LOG_DAEMON);
+  openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON);
 
 #ifdef HAVE_LIBSODIUM
   if (sodium_init() == -1) {
@@ -2330,6 +2447,9 @@ try
 #ifdef HAVE_FSTRM
       cout<<"fstrm ";
 #endif
+#ifdef HAVE_LIBCRYPTO
+      cout<<"ipcipher ";
+#endif
 #ifdef HAVE_LIBSODIUM
       cout<<"libsodium ";
 #endif
@@ -2425,293 +2545,42 @@ try
     }
   }
 
-  if(g_cmdLine.locals.size()) {
-    g_locals.clear();
-    for(auto loc : g_cmdLine.locals)
-      g_locals.push_back(std::make_tuple(ComboAddress(loc, 53), true, false, 0, "", std::set<int>()));
-  }
-  
-  if(g_locals.empty())
-    g_locals.push_back(std::make_tuple(ComboAddress("127.0.0.1", 53), true, false, 0, "", std::set<int>()));
-
-  g_configurationDone = true;
-
-  vector<ClientState*> toLaunch;
-  for(const auto& local : g_locals) {
-    ClientState* cs = new ClientState;
-    cs->local= std::get<0>(local);
-    cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
-    if(cs->local.sin4.sin_family == AF_INET6) {
-      SSetsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
-    }
-    //if(g_vm.count("bind-non-local"))
-    bindAny(cs->local.sin4.sin_family, cs->udpFD);
-
-    //    if (!setSocketTimestamps(cs->udpFD))
-    //      g_log<<Logger::Warning<<"Unable to enable timestamp reporting for socket"<<endl;
-
-
-    if(IsAnyAddress(cs->local)) {
-      int one=1;
-      setsockopt(cs->udpFD, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one));     // linux supports this, so why not - might fail on other systems
-#ifdef IPV6_RECVPKTINFO
-      setsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
-#endif
-    }
-
-    if (std::get<2>(local)) {
-#ifdef SO_REUSEPORT
-      SSetsockopt(cs->udpFD, SOL_SOCKET, SO_REUSEPORT, 1);
-#else
-      warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", std::get<0>(local).toStringWithPort());
-#endif
-    }
-
-    const std::string& itf = std::get<4>(local);
-    if (!itf.empty()) {
-#ifdef SO_BINDTODEVICE
-      int res = setsockopt(cs->udpFD, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
-      if (res != 0) {
-        warnlog("Error setting up the interface on local address '%s': %s", std::get<0>(local).toStringWithPort(), strerror(errno));
+  if (!g_cmdLine.locals.empty()) {
+    for (auto it = g_frontends.begin(); it != g_frontends.end(); ) {
+      /* TLS and DNSCrypt frontends are separate */
+      if ((*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr) {
+        it = g_frontends.erase(it);
       }
-#else
-      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", std::get<0>(local).toStringWithPort());
-#endif
-    }
-
-#ifdef HAVE_EBPF
-    if (g_defaultBPFFilter) {
-      cs->attachFilter(g_defaultBPFFilter);
-      vinfolog("Attaching default BPF Filter to UDP frontend %s", cs->local.toStringWithPort());
-    }
-#endif /* HAVE_EBPF */
-
-    cs->cpus = std::get<5>(local);
-
-    SBind(cs->udpFD, cs->local);
-    toLaunch.push_back(cs);
-    g_frontends.push_back(cs);
-    udpBindsCount++;
-  }
-
-  for(const auto& local : g_locals) {
-    if(!std::get<1>(local)) { // no TCP/IP
-      warnlog("Not providing TCP/IP service on local address '%s'", std::get<0>(local).toStringWithPort());
-      continue;
-    }
-    ClientState* cs = new ClientState;
-    cs->local= std::get<0>(local);
-
-    cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0);
-
-    SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef TCP_DEFER_ACCEPT
-    SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
-#endif
-    if (std::get<3>(local) > 0) {
-#ifdef TCP_FASTOPEN
-      SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_FASTOPEN, std::get<3>(local));
-#else
-      warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", std::get<0>(local).toStringWithPort());
-#endif
-    }
-    if(cs->local.sin4.sin_family == AF_INET6) {
-      SSetsockopt(cs->tcpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
-    }
-#ifdef SO_REUSEPORT
-    /* no need to warn again if configured but support is not available, we already did for UDP */
-    if (std::get<2>(local)) {
-      SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEPORT, 1);
-    }
-#endif
-
-    const std::string& itf = std::get<4>(local);
-    if (!itf.empty()) {
-#ifdef SO_BINDTODEVICE
-      int res = setsockopt(cs->tcpFD, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
-      if (res != 0) {
-        warnlog("Error setting up the interface on local address '%s': %s", std::get<0>(local).toStringWithPort(), strerror(errno));
+      else {
+        ++it;
       }
-#else
-      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", std::get<0>(local).toStringWithPort());
-#endif
     }
 
-#ifdef HAVE_EBPF
-    if (g_defaultBPFFilter) {
-      cs->attachFilter(g_defaultBPFFilter);
-      vinfolog("Attaching default BPF Filter to TCP frontend %s", cs->local.toStringWithPort());
+    for(const auto& loc : g_cmdLine.locals) {
+      /* UDP */
+      g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), false, false, 0, "", {})));
+      /* TCP */
+      g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), true, false, 0, "", {})));
     }
-#endif /* HAVE_EBPF */
-
-    //    if(g_vm.count("bind-non-local"))
-      bindAny(cs->local.sin4.sin_family, cs->tcpFD);
-    SBind(cs->tcpFD, cs->local);
-    SListen(cs->tcpFD, 64);
-    warnlog("Listening on %s", cs->local.toStringWithPort());
-
-    toLaunch.push_back(cs);
-    g_frontends.push_back(cs);
-    tcpBindsCount++;
   }
 
-#ifdef HAVE_DNSCRYPT
-  for(auto& dcLocal : g_dnsCryptLocals) {
-    ClientState* cs = new ClientState;
-    cs->local = std::get<0>(dcLocal);
-    cs->dnscryptCtx = std::get<1>(dcLocal);
-    cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
-    if(cs->local.sin4.sin_family == AF_INET6) {
-      SSetsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
-    }
-    bindAny(cs->local.sin4.sin_family, cs->udpFD);
-    if(IsAnyAddress(cs->local)) {
-      int one=1;
-      setsockopt(cs->udpFD, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one));     // linux supports this, so why not - might fail on other systems
-#ifdef IPV6_RECVPKTINFO
-      setsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)); 
-#endif
-    }
-    if (std::get<2>(dcLocal)) {
-#ifdef SO_REUSEPORT
-      SSetsockopt(cs->udpFD, SOL_SOCKET, SO_REUSEPORT, 1);
-#else
-      warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", std::get<0>(dcLocal).toStringWithPort());
-#endif
-    }
-
-    const std::string& itf = std::get<4>(dcLocal);
-    if (!itf.empty()) {
-#ifdef SO_BINDTODEVICE
-      int res = setsockopt(cs->udpFD, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
-      if (res != 0) {
-        warnlog("Error setting up the interface on local address '%s': %s", std::get<0>(dcLocal).toStringWithPort(), strerror(errno));
-      }
-#else
-      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", std::get<0>(dcLocal).toStringWithPort());
-#endif
-    }
-
-#ifdef HAVE_EBPF
-    if (g_defaultBPFFilter) {
-      cs->attachFilter(g_defaultBPFFilter);
-      vinfolog("Attaching default BPF Filter to UDP DNSCrypt frontend %s", cs->local.toStringWithPort());
-    }
-#endif /* HAVE_EBPF */
-    SBind(cs->udpFD, cs->local);    
-    toLaunch.push_back(cs);
-    g_frontends.push_back(cs);
-    udpBindsCount++;
-
-    cs = new ClientState;
-    cs->local = std::get<0>(dcLocal);
-    cs->dnscryptCtx = std::get<1>(dcLocal);
-    cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0);
-    SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef TCP_DEFER_ACCEPT
-    SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
-#endif
-    if (std::get<3>(dcLocal) > 0) {
-#ifdef TCP_FASTOPEN
-      SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_FASTOPEN, std::get<3>(dcLocal));
-#else
-      warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", std::get<0>(dcLocal).toStringWithPort());
-#endif
-    }
-
-#ifdef SO_REUSEPORT
-    /* no need to warn again if configured but support is not available, we already did for UDP */
-    if (std::get<2>(dcLocal)) {
-      SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEPORT, 1);
-    }
-#endif
-
-    if (!itf.empty()) {
-#ifdef SO_BINDTODEVICE
-      int res = setsockopt(cs->tcpFD, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
-      if (res != 0) {
-        warnlog("Error setting up the interface on local address '%s': %s", std::get<0>(dcLocal).toStringWithPort(), strerror(errno));
-      }
-#else
-      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", std::get<0>(dcLocal).toStringWithPort());
-#endif
-    }
-
-    if(cs->local.sin4.sin_family == AF_INET6) {
-      SSetsockopt(cs->tcpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
-    }
-#ifdef HAVE_EBPF
-    if (g_defaultBPFFilter) {
-      cs->attachFilter(g_defaultBPFFilter);
-      vinfolog("Attaching default BPF Filter to TCP DNSCrypt frontend %s", cs->local.toStringWithPort());
-    }
-#endif /* HAVE_EBPF */
-
-    cs->cpus = std::get<5>(dcLocal);
-
-    bindAny(cs->local.sin4.sin_family, cs->tcpFD);
-    SBind(cs->tcpFD, cs->local);
-    SListen(cs->tcpFD, 64);
-    warnlog("Listening on %s", cs->local.toStringWithPort());
-    toLaunch.push_back(cs);
-    g_frontends.push_back(cs);
-    tcpBindsCount++;
+  if (g_frontends.empty()) {
+    /* UDP */
+    g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), false, false, 0, "", {})));
+    /* TCP */
+    g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), true, false, 0, "", {})));
   }
-#endif
-
-  for(auto& frontend : g_tlslocals) {
-    ClientState* cs = new ClientState;
-    cs->local = frontend->d_addr;
-    cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0);
-    SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef TCP_DEFER_ACCEPT
-    SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
-#endif
-    if (frontend->d_tcpFastOpenQueueSize > 0) {
-#ifdef TCP_FASTOPEN
-      SSetsockopt(cs->tcpFD, IPPROTO_TCP, TCP_FASTOPEN, frontend->d_tcpFastOpenQueueSize);
-#else
-      warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
-#endif
-    }
-    if (frontend->d_reusePort) {
-#ifdef SO_REUSEPORT
-      SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEPORT, 1);
-#else
-      warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
-#endif
-    }
-    if(cs->local.sin4.sin_family == AF_INET6) {
-      SSetsockopt(cs->tcpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
-    }
 
-    if (!frontend->d_interface.empty()) {
-#ifdef SO_BINDTODEVICE
-      int res = setsockopt(cs->tcpFD, SOL_SOCKET, SO_BINDTODEVICE, frontend->d_interface.c_str(), frontend->d_interface.length());
-      if (res != 0) {
-        warnlog("Error setting up the interface on local address '%s': %s", cs->local.toStringWithPort(), strerror(errno));
-      }
-#else
-      warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", cs->local.toStringWithPort());
-#endif
-    }
+  g_configurationDone = true;
 
-    cs->cpus = frontend->d_cpus;
+  for(auto& frontend : g_frontends) {
+    setUpLocalBind(frontend);
 
-    bindAny(cs->local.sin4.sin_family, cs->tcpFD);
-    if (frontend->setupTLS()) {
-      cs->tlsFrontend = frontend;
-      SBind(cs->tcpFD, cs->local);
-      SListen(cs->tcpFD, 64);
-      warnlog("Listening on %s for TLS", cs->local.toStringWithPort());
-      toLaunch.push_back(cs);
-      g_frontends.push_back(cs);
-      tcpBindsCount++;
+    if (frontend->tcp == false) {
+      ++udpBindsCount;
     }
     else {
-      errlog("Error while setting up TLS on local address '%s', exiting", cs->local.toStringWithPort());
-      delete cs;
-      _exit(EXIT_FAILURE);
+      ++tcpBindsCount;
     }
   }
 
@@ -2774,7 +2643,7 @@ try
     g_snmpAgent->run();
   }
 
-  g_tcpclientthreads = std::make_shared<TCPClientCollection>(g_maxTCPClientThreads, g_useTCPSinglePipe);
+  g_tcpclientthreads = std::unique_ptr<TCPClientCollection>(new TCPClientCollection(g_maxTCPClientThreads, g_useTCPSinglePipe));
 
   for(auto& t : todo)
     t();
@@ -2803,22 +2672,22 @@ try
 
   for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
     if(dss->availability==DownstreamState::Availability::Auto) {
-      bool newState=upCheck(*dss);
+      bool newState=upCheck(dss);
       warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
       dss->upStatus = newState;
     }
   }
 
-  for(auto& cs : toLaunch) {
+  for(auto& cs : g_frontends) {
     if (cs->udpFD >= 0) {
-      thread t1(udpClientThread, cs);
+      thread t1(udpClientThread, cs.get());
       if (!cs->cpus.empty()) {
         mapThreadToCPUList(t1.native_handle(), cs->cpus);
       }
       t1.detach();
     }
     else if (cs->tcpFD >= 0) {
-      thread t1(tcpAcceptorThread, cs);
+      thread t1(tcpAcceptorThread, cs.get());
       if (!cs->cpus.empty()) {
         mapThreadToCPUList(t1.native_handle(), cs->cpus);
       }
index 925d1d0d1d639218a1c44339b06d6abbf72e13a4..71651495fa54f218a4a2ca2c2e51927fdbf6950a 100644 (file)
@@ -61,39 +61,59 @@ typedef std::unordered_map<string, string> QTag;
 struct DNSQuestion
 {
   DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_):
-    qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), consumed(consumed_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tempFailureTTL(boost::none), tcp(isTcp), queryTime(queryTime_), ecsOverride(g_ECSOverride) { }
+    qname(name), local(lc), remote(rem), dh(header), queryTime(queryTime_), size(bufferSize), consumed(consumed_), tempFailureTTL(boost::none), qtype(type), qclass(class_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) {
+    const uint16_t* flags = getFlagsFromDNSHeader(dh);
+    origFlags = *flags;
+  }
+  DNSQuestion(const DNSQuestion&) = delete;
+  DNSQuestion& operator=(const DNSQuestion&) = delete;
+  DNSQuestion(DNSQuestion&&) = default;
 
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
   Netmask ecs;
-  const DNSName* qname;
-  const uint16_t qtype;
-  const uint16_t qclass;
-  const ComboAddress* local;
-  const ComboAddress* remote;
+  boost::optional<Netmask> subnet;
+  const DNSName* qname{nullptr};
+  const ComboAddress* local{nullptr};
+  const ComboAddress* remote{nullptr};
   std::shared_ptr<QTag> qTag{nullptr};
   std::shared_ptr<std::map<uint16_t, EDNSOptionView> > ednsOptions;
-  struct dnsheader* dh;
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+  struct dnsheader* dh{nullptr};
+  const struct timespec* queryTime{nullptr};
   size_t size;
   unsigned int consumed{0};
+  int delayMsec{0};
+  boost::optional<uint32_t> tempFailureTTL;
+  uint32_t cacheKeyNoECS;
+  uint32_t cacheKey;
+  const uint16_t qtype;
+  const uint16_t qclass;
   uint16_t len;
   uint16_t ecsPrefixLength;
+  uint16_t origFlags;
   uint8_t ednsRCode{0};
-  boost::optional<uint32_t> tempFailureTTL;
   const bool tcp;
-  const struct timespec* queryTime;
   bool skipCache{false};
   bool ecsOverride;
   bool useECS{true};
   bool addXPF{true};
   bool ecsSet{false};
+  bool ecsAdded{false};
+  bool ednsAdded{false};
+  bool useZeroScope{false};
+  bool dnssecOK{false};
 };
 
 struct DNSResponse : DNSQuestion
 {
   DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_):
     DNSQuestion(name, type, class_, consumed, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { }
+  DNSResponse(const DNSResponse&) = delete;
+  DNSResponse& operator=(const DNSResponse&) = delete;
+  DNSResponse(DNSResponse&&) = default;
 };
 
 /* so what could you do:
@@ -208,6 +228,9 @@ struct DNSDistStats
   stat_t responses{0};
   stat_t servfailResponses{0};
   stat_t queries{0};
+  stat_t frontendNXDomain{0};
+  stat_t frontendServFail{0};
+  stat_t frontendNoError{0};
   stat_t nonCompliantQueries{0};
   stat_t nonCompliantResponses{0};
   stat_t rdQueries{0};
@@ -235,6 +258,9 @@ struct DNSDistStats
     {"responses", &responses},
     {"servfail-responses", &servfailResponses},
     {"queries", &queries},
+    {"frontend-nxdomain", &frontendNXDomain},
+    {"frontend-servfail", &frontendServFail},
+    {"frontend-noerror", &frontendNoError},
     {"acl-drops", &aclDrops},
     {"rule-drop", &ruleDrop},
     {"rule-nxdomain", &ruleNXDomain},
@@ -324,6 +350,9 @@ struct MetricDefinitionStorage {
     { "responses",              MetricDefinition(PrometheusMetricType::counter, "Number of responses received from backends") },
     { "servfail-responses",     MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers received from backends") },
     { "queries",                MetricDefinition(PrometheusMetricType::counter, "Number of received queries")},
+    { "frontend-nxdomain",      MetricDefinition(PrometheusMetricType::counter, "Number of NXDomain answers sent to clients")},
+    { "frontend-servfail",      MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers sent to clients")},
+    { "frontend-noerror",       MetricDefinition(PrometheusMetricType::counter, "Number of NoError answers sent to clients")},
     { "acl-drops",              MetricDefinition(PrometheusMetricType::counter, "Number of packets dropped because of the ACL")},
     { "rule-drop",              MetricDefinition(PrometheusMetricType::counter, "Number of queries dropped because of a rule")},
     { "rule-nxdomain",          MetricDefinition(PrometheusMetricType::counter, "Number of NXDomain answers returned because of a rule")},
@@ -516,9 +545,7 @@ struct IDState
   ComboAddress origDest;                                      // 28
   StopWatch sentTime;                                         // 16
   DNSName qname;                                              // 80
-#ifdef HAVE_DNSCRYPT
   std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
-#endif
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
@@ -544,7 +571,7 @@ struct IDState
 };
 
 typedef std::unordered_map<string, unsigned int> QueryCountRecords;
-typedef std::function<std::tuple<bool, string>(DNSQuestion dq)> QueryCountFilter;
+typedef std::function<std::tuple<bool, string>(const DNSQuestion* dq)> QueryCountFilter;
 struct QueryCount {
   QueryCount()
   {
@@ -560,22 +587,52 @@ extern QueryCount g_qcount;
 
 struct ClientState
 {
+  ClientState(const ComboAddress& local_, bool isTCP, bool doReusePort, int fastOpenQueue, const std::string& itfName, const std::set<int>& cpus_): cpus(cpus_), local(local_), interface(itfName), fastOpenQueueSize(fastOpenQueue), tcp(isTCP), reuseport(doReusePort)
+  {
+  }
+
   std::set<int> cpus;
   ComboAddress local;
-#ifdef HAVE_DNSCRYPT
   std::shared_ptr<DNSCryptContext> dnscryptCtx{nullptr};
-#endif
-  shared_ptr<TLSFrontend> tlsFrontend;
+  std::shared_ptr<TLSFrontend> tlsFrontend{nullptr};
+  std::string interface;
   std::atomic<uint64_t> queries{0};
+  std::atomic<uint64_t> tcpDiedReadingQuery{0};
+  std::atomic<uint64_t> tcpDiedSendingResponse{0};
+  std::atomic<uint64_t> tcpGaveUp{0};
+  std::atomic<uint64_t> tcpClientTimeouts{0};
+  std::atomic<uint64_t> tcpDownstreamTimeouts{0};
+  std::atomic<uint64_t> tcpCurrentConnections{0};
+  std::atomic<double> tcpAvgQueriesPerConnection{0.0};
+  /* in ms */
+  std::atomic<double> tcpAvgConnectionDuration{0.0};
   int udpFD{-1};
   int tcpFD{-1};
+  int fastOpenQueueSize{0};
   bool muted{false};
+  bool tcp;
+  bool reuseport;
+  bool ready{false};
 
   int getSocket() const
   {
     return udpFD != -1 ? udpFD : tcpFD;
   }
 
+  std::string getType() const
+  {
+    std::string result = udpFD != -1 ? "UDP" : "TCP";
+
+    if (tlsFrontend) {
+      result += " (DNS over TLS)";
+    }
+    else if (dnscryptCtx) {
+      result += " (DNSCrypt)";
+    }
+
+    return result;
+  }
+
 #ifdef HAVE_EBPF
   shared_ptr<BPFFilter> d_filter;
 
@@ -595,6 +652,12 @@ struct ClientState
     d_filter = bpf;
   }
 #endif /* HAVE_EBPF */
+
+  void updateTCPMetrics(size_t queries, uint64_t durationMs)
+  {
+    tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (queries / 100.0);
+    tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
+  }
 };
 
 class TCPClientCollection {
@@ -617,6 +680,14 @@ public:
       if (pipe(d_singlePipe) < 0) {
         throw std::runtime_error("Error creating the TCP single communication pipe: " + string(strerror(errno)));
       }
+
+      if (!setNonBlocking(d_singlePipe[0])) {
+        int err = errno;
+        close(d_singlePipe[0]);
+        close(d_singlePipe[1]);
+        throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + string(strerror(err)));
+      }
+
       if (!setNonBlocking(d_singlePipe[1])) {
         int err = errno;
         close(d_singlePipe[0]);
@@ -650,7 +721,7 @@ public:
   void addTCPClientThread();
 };
 
-extern std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+extern std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
 
 struct DownstreamState
 {
@@ -693,6 +764,15 @@ struct DownstreamState
     std::atomic<uint64_t> reuseds{0};
     std::atomic<uint64_t> queries{0};
   } prev;
+  std::atomic<uint64_t> tcpDiedSendingQuery{0};
+  std::atomic<uint64_t> tcpDiedReadingResponse{0};
+  std::atomic<uint64_t> tcpGaveUp{0};
+  std::atomic<uint64_t> tcpReadTimeouts{0};
+  std::atomic<uint64_t> tcpWriteTimeouts{0};
+  std::atomic<uint64_t> tcpCurrentConnections{0};
+  std::atomic<double> tcpAvgQueriesPerConnection{0.0};
+  /* in ms */
+  std::atomic<double> tcpAvgConnectionDuration{0.0};
   string name;
   size_t socketsOffset{0};
   double queryLoad{0.0};
@@ -710,7 +790,7 @@ struct DownstreamState
   uint16_t xpfRRCode{0};
   uint16_t checkTimeout{1000}; /* in milliseconds */
   uint8_t currentCheckFailures{0};
-  uint8_t consecutiveSuccesfulChecks{0};
+  uint8_t consecutiveSuccessfulChecks{0};
   uint8_t maxCheckFailures{1};
   uint8_t minRiseSuccesses{1};
   StopWatch sw;
@@ -764,6 +844,12 @@ struct DownstreamState
   void hash();
   void setId(const boost::uuids::uuid& newId);
   void setWeight(int newWeight);
+
+  void updateTCPMetrics(size_t queries, uint64_t durationMs)
+  {
+    tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (queries / 100.0);
+    tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
+  }
 };
 using servers_t =vector<std::shared_ptr<DownstreamState>>;
 
@@ -937,7 +1023,7 @@ extern ComboAddress g_serverControl; // not changed during runtime
 
 extern std::vector<std::tuple<ComboAddress, bool, bool, int, std::string, std::set<int>>> g_locals; // not changed at runtime (we hope XXX)
 extern std::vector<shared_ptr<TLSFrontend>> g_tlslocals;
-extern vector<ClientState*> g_frontends;
+extern std::vector<std::unique_ptr<ClientState>> g_frontends;
 extern bool g_truncateTC;
 extern bool g_fixupCase;
 extern int g_tcpRecvTimeout;
@@ -959,10 +1045,11 @@ extern std::string g_apiConfigDirectory;
 extern bool g_servFailOnNoPolicy;
 extern uint32_t g_hashperturb;
 extern bool g_useTCPSinglePipe;
-extern std::atomic<uint16_t> g_downstreamTCPCleanupInterval;
+extern uint16_t g_downstreamTCPCleanupInterval;
 extern size_t g_udpVectorSize;
 extern bool g_preserveTrailingData;
 extern bool g_allowEmptyResponse;
+extern bool g_roundrobinFailOnNoServer;
 
 #ifdef HAVE_EBPF
 extern shared_ptr<BPFFilter> g_defaultBPFFilter;
@@ -1023,19 +1110,13 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed);
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now);
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec);
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags);
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope);
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
-bool checkQueryHeaders(const struct dnsheader* dh);
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted);
 
-#ifdef HAVE_DNSCRYPT
-extern std::vector<std::tuple<ComboAddress, std::shared_ptr<DNSCryptContext>, bool, int, std::string, std::set<int> > > g_dnsCryptLocals;
+bool checkQueryHeaders(const struct dnsheader* dh);
 
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
+extern std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;
 int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response);
-#endif
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp);
 
 bool addXPF(DNSQuestion& dq, uint16_t optionCode);
 
@@ -1049,3 +1130,9 @@ extern DNSDistSNMPAgent* g_snmpAgent;
 extern bool g_addEDNSToSelfGeneratedResponses;
 
 static const size_t s_udpIncomingBufferSize{1500};
+
+enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend };
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
+
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP);
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname);
index ceeb6d04048aef9f151cb51b60be921855be2a60..3c438681454a1f657f410cf7d41b5ae88045851d 100644 (file)
@@ -2,7 +2,8 @@ AM_CPPFLAGS += $(SYSTEMD_CFLAGS) $(LUA_CFLAGS) $(LIBEDIT_CFLAGS) $(LIBSODIUM_CFL
 
 ACLOCAL_AMFLAGS = -I m4
 
-SUBDIRS=ext/yahttp
+SUBDIRS=ext/ipcrypt \
+       ext/yahttp
 
 CLEANFILES = \
        dnsmessage.pb.cc \
@@ -44,6 +45,10 @@ AM_CPPFLAGS += $(GNUTLS_CFLAGS)
 endif
 endif
 
+if HAVE_LIBCRYPTO
+AM_CPPFLAGS += $(LIBCRYPTO_INCLUDES)
+endif
+
 EXTRA_DIST=COPYING \
           dnslabeltext.rl \
           dnsdistconf.lua \
@@ -98,11 +103,13 @@ dnsdist_SOURCES = \
        dnsdist-dnscrypt.cc \
        dnsdist-dynblocks.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnsdist-idstate.cc \
        dnsdist-lua.hh dnsdist-lua.cc \
        dnsdist-lua-actions.cc \
        dnsdist-lua-bindings.cc \
        dnsdist-lua-bindings-dnsquestion.cc \
        dnsdist-lua-inspection.cc \
+       dnsdist-lua-inspection-ffi.cc dnsdist-lua-inspection-ffi.hh \
        dnsdist-lua-rules.cc \
        dnsdist-lua-vars.cc \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
@@ -165,19 +172,25 @@ dnsdist_LDADD = \
        $(SANITIZER_FLAGS) \
        $(SYSTEMD_LIBS) \
        $(NET_SNMP_LIBS) \
-       $(LIBCAP_LIBS)
+       $(LIBCAP_LIBS) \
+       $(IPCRYPT_LIBS)
 
 if HAVE_RE2
 dnsdist_LDADD += $(RE2_LIBS)
 endif
 
+if HAVE_LIBCRYPTO
+dnsdist_LDADD += $(LIBCRYPTO_LIBS)
+dnsdist_SOURCES += ipcipher.cc ipcipher.hh
+endif
+
 if HAVE_DNS_OVER_TLS
 if HAVE_GNUTLS
 dnsdist_LDADD += -lgnutls
 endif
 
 if HAVE_LIBSSL
-dnsdist_LDADD += $(LIBSSL_LIBS) $(LIBCRYPTO_LIBS)
+dnsdist_LDADD += $(LIBSSL_LIBS)
 endif
 endif
 
@@ -204,20 +217,6 @@ dnsdist.$(OBJEXT): dnsmessage.pb.cc dnstap.pb.cc
 endif
 endif
 
-if HAVE_FREEBSD
-dnsdist_SOURCES += kqueuemplexer.cc
-endif
-
-if HAVE_LINUX
-dnsdist_SOURCES += epollmplexer.cc
-endif
-
-if HAVE_SOLARIS
-dnsdist_SOURCES += \
-        devpollmplexer.cc \
-        portsmplexer.cc
-endif
-
 testrunner_SOURCES = \
        base64.hh \
        dns.hh \
@@ -231,6 +230,7 @@ testrunner_SOURCES = \
        test-dnsdistrules_cc.cc \
        test-dnsparser_cc.cc \
        test-iputils_hh.cc \
+       test-mplexer.cc \
        cachecleaner.hh \
        dnsdist.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
@@ -250,14 +250,35 @@ testrunner_SOURCES = \
        misc.cc misc.hh \
        namespaces.hh \
        pdnsexception.hh \
+       pollmplexer.cc \
        qtype.cc qtype.hh \
        sholder.hh \
        sodcrypto.cc \
        sstuff.hh \
+       statnode.cc statnode.hh \
        threadname.hh threadname.cc \
        testrunner.cc \
        xpf.cc xpf.hh
 
+if HAVE_FREEBSD
+dnsdist_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
+endif
+
+if HAVE_LINUX
+dnsdist_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
+endif
+
+if HAVE_SOLARIS
+dnsdist_SOURCES += \
+        devpollmplexer.cc \
+        portsmplexer.cc
+testrunner_SOURCES += \
+        devpollmplexer.cc \
+        portsmplexer.cc
+endif
+
 testrunner_LDFLAGS = \
        $(AM_LDFLAGS) \
        $(PROGRAM_LDFLAGS) \
index 1013e0d437682a6fe0f703f98b883440ea66df0c..7d8014c955abcc74a910a64bf9a4a7aceb031c48 100644 (file)
@@ -35,8 +35,7 @@ PDNS_CHECK_SECURE_MEMSET
 
 PDNS_WITH_PROTOBUF
 
-boost_required_version=1.42
-BOOST_REQUIRE([$boost_required_version])
+BOOST_REQUIRE([1.42])
 
 PDNS_ENABLE_UNIT_TESTS
 PDNS_WITH_RE2
@@ -51,20 +50,27 @@ AM_CONDITIONAL([HAVE_SYSTEMD], [ test x"$systemd" = "xy" ])
 
 AC_SUBST([YAHTTP_CFLAGS], ['-I$(top_srcdir)/ext/yahttp'])
 AC_SUBST([YAHTTP_LIBS], ['$(top_builddir)/ext/yahttp/yahttp/libyahttp.la'])
+AC_SUBST([IPCRYPT_CFLAGS], ['-I$(top_srcdir)/ext/ipcrypt'])
+AC_SUBST([IPCRYPT_LIBS], ['$(top_builddir)/ext/ipcrypt/libipcrypt.la'])
 
 PDNS_WITH_LUA([mandatory])
+AS_IF([test "x$LUAPC" = "xluajit"], [
+  # export all symbols to be able to use the Lua FFI interface
+  AC_MSG_NOTICE([Adding -rdynamic to export all symbols for the Lua FFI interface])
+  LDFLAGS="$LDFLAGS -rdynamic"
+])
 PDNS_CHECK_LUA_HPP
 
 AM_CONDITIONAL([HAVE_GNUTLS], [false])
 AM_CONDITIONAL([HAVE_LIBSSL], [false])
+
+PDNS_CHECK_LIBCRYPTO
+
 DNSDIST_ENABLE_DNS_OVER_TLS
+
 AS_IF([test "x$enable_dns_over_tls" != "xno"], [
   DNSDIST_WITH_GNUTLS
   DNSDIST_WITH_LIBSSL
-  AS_IF([test "$HAVE_LIBSSL" = "1"], [
-    # we need libcrypto if libssl is enabled
-    PDNS_CHECK_LIBCRYPTO
-  ])
   AS_IF([test "$HAVE_GNUTLS" = "0" -a "$HAVE_LIBSSL" = "0"], [
     AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
   ])
@@ -110,8 +116,9 @@ AS_IF([test "x$PACKAGEVERSION" != "x"],
 )
 
 AC_CONFIG_FILES([Makefile
-       ext/yahttp/Makefile
-       ext/yahttp/yahttp/Makefile])
+        ext/yahttp/Makefile
+        ext/yahttp/yahttp/Makefile
+        ext/ipcrypt/Makefile])
 
 AC_OUTPUT
 
@@ -146,6 +153,10 @@ AS_IF([test "x$systemd" != "xn"],
   [AC_MSG_NOTICE([systemd: yes])],
   [AC_MSG_NOTICE([systemd: no])]
 )
+AS_IF([test "x$LIBCRYPTO_LIBS" != "x"],
+  [AC_MSG_NOTICE([ipcipher: yes])],
+  [AC_MSG_NOTICE([ipcipher: no])]
+)
 AS_IF([test "x$LIBSODIUM_LIBS" != "x"],
   [AC_MSG_NOTICE([libsodium: yes])],
   [AC_MSG_NOTICE([libsodium: no])]
diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc
new file mode 100644 (file)
index 0000000..169ba64
--- /dev/null
@@ -0,0 +1,58 @@
+
+#include "dnsdist.hh"
+
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP)
+{
+  
+  DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, ids.qname.wirelength(), &ids.origDest, &ids.origRemote, dh, bufferSize, responseLen, isTCP, &ids.sentTime.d_start);
+  dr.origFlags = ids.origFlags;
+  dr.ecsAdded = ids.ecsAdded;
+  dr.ednsAdded = ids.ednsAdded;
+  dr.useZeroScope = ids.useZeroScope;
+  dr.packetCache = std::move(ids.packetCache);
+  dr.delayMsec = ids.delayMsec;
+  dr.skipCache = ids.skipCache;
+  dr.cacheKey = ids.cacheKey;
+  dr.cacheKeyNoECS = ids.cacheKeyNoECS;
+  dr.dnssecOK = ids.dnssecOK;
+  dr.tempFailureTTL = ids.tempFailureTTL;
+  dr.qTag = std::move(ids.qTag);
+  dr.subnet = std::move(ids.subnet);
+#ifdef HAVE_PROTOBUF
+  dr.uniqueId = std::move(ids.uniqueId);
+#endif
+  if (ids.dnsCryptQuery) {
+    dr.dnsCryptQuery = std::move(ids.dnsCryptQuery);
+  }
+
+  return dr;  
+}
+
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
+{
+  ids.origRemote = *dq.remote;
+  ids.origDest = *dq.local;
+  ids.sentTime.set(*dq.queryTime);
+  ids.qname = std::move(qname);
+  ids.qtype = dq.qtype;
+  ids.qclass = dq.qclass;
+  ids.delayMsec = dq.delayMsec;
+  ids.tempFailureTTL = dq.tempFailureTTL;
+  ids.origFlags = dq.origFlags;
+  ids.cacheKey = dq.cacheKey;
+  ids.cacheKeyNoECS = dq.cacheKeyNoECS;
+  ids.subnet = dq.subnet;
+  ids.skipCache = dq.skipCache;
+  ids.packetCache = dq.packetCache;
+  ids.ednsAdded = dq.ednsAdded;
+  ids.ecsAdded = dq.ecsAdded;
+  ids.useZeroScope = dq.useZeroScope;
+  ids.qTag = dq.qTag;
+  ids.dnssecOK = dq.dnssecOK;
+  
+  ids.dnsCryptQuery = std::move(dq.dnsCryptQuery);
+  
+#ifdef HAVE_PROTOBUF
+  ids.uniqueId = std::move(dq.uniqueId);
+#endif
+}
diff --git a/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc
new file mode 100644 (file)
index 0000000..d05882a
--- /dev/null
@@ -0,0 +1,91 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dnsdist-dynblocks.hh"
+
+uint64_t dnsdist_ffi_stat_node_get_queries_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.queries;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_noerrors_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.noerrors;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_nxdomains_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.nxdomains;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_servfails_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.servfails;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_drops_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.drops;
+}
+
+unsigned int dnsdist_ffi_stat_node_get_labels_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->node.labelsCount;
+}
+
+void dnsdist_ffi_stat_node_get_full_name_raw(const dnsdist_ffi_stat_node_t* node, const char** name, size_t* nameSize)
+{
+  const auto& storage = node->node.fullname;
+  *name = storage.c_str();
+  *nameSize = storage.size();
+}
+
+unsigned int dnsdist_ffi_stat_node_get_children_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->node.children.size();
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_queries_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.queries;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_noerrors_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.noerrors;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_nxdomains_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.nxdomains;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_servfails_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.servfails;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_drops_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.drops;
+}
diff --git a/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh b/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh
new file mode 100644 (file)
index 0000000..c4329a2
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+extern "C" {
+  typedef struct dnsdist_ffi_stat_node_t dnsdist_ffi_stat_node_t;
+
+  uint64_t dnsdist_ffi_stat_node_get_queries_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_noerrors_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_nxdomains_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_servfails_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_drops_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  unsigned int dnsdist_ffi_stat_node_get_labels_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  void dnsdist_ffi_stat_node_get_full_name_raw(const dnsdist_ffi_stat_node_t* node, const char** name, size_t* nameSize) __attribute__ ((visibility ("default")));
+
+  unsigned int dnsdist_ffi_stat_node_get_children_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+
+  uint64_t dnsdist_ffi_stat_node_get_children_queries_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_noerrors_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_nxdomains_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_servfails_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_drops_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+}
index 7361a4808c71c4b256e90289bfd98c31f5219de7..bbc4ff2056735d74af71f92d0c8396ad405f67a3 100644 (file)
@@ -863,30 +863,10 @@ public:
       return false;
     }
 
-    uint16_t optStart;
-    size_t optLen = 0;
-    bool last = false;
-    const char * packet = reinterpret_cast<const char*>(dq->dh);
-    std::string packetStr(packet, dq->len);
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return d_extrcode == 0;
-    }
-
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
-      return false;
-    }
-
-    if (optStart < dq->len && packet[optStart] != 0) {
-      // OPT RR Name != '.'
+    EDNS0Record edns0;
+    if (!getEDNS0Record(*dq, edns0)) {
       return false;
     }
-    EDNS0Record edns0;
-    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
-    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
-    memcpy(&edns0, packet + optStart + 5, sizeof edns0);
 
     return d_extrcode == edns0.extRCode;
   }
@@ -907,30 +887,10 @@ public:
   }
   bool matches(const DNSQuestion* dq) const override
   {
-    uint16_t optStart;
-    size_t optLen = 0;
-    bool last = false;
-    const char * packet = reinterpret_cast<const char*>(dq->dh);
-    std::string packetStr(packet, dq->len);
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return false;
-    }
-
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
-      return false;
-    }
-
-    if (optStart < dq->len && packetStr.at(optStart) != 0) {
-      // OPT RR Name != '.'
+    EDNS0Record edns0;
+    if (!getEDNS0Record(*dq, edns0)) {
       return false;
     }
-    EDNS0Record edns0;
-    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
-    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
-    memcpy(&edns0, packet + optStart + 5, sizeof edns0);
 
     return d_version < edns0.version;
   }
@@ -961,8 +921,7 @@ public:
       return false;
     }
 
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
+    if (optLen < optRecordMinimumSize) {
       return false;
     }
 
index dcc787ea42966e464c9613d1acd878ae415d8db5..dc883b85ce640254a59e0d57349cae6470c142b8 100644 (file)
@@ -13,12 +13,16 @@ AXFR or IXFR queries destined to this master. There are two issues that can aris
 The first issue can be solved by routing SOA, AXFR and IXFR requests explicitly to the master::
 
   newServer({address="192.168.1.2", name="master", pool={"master", "otherpool"}})
-  addAction(OrRule({QTypeRule(dnsdist.SOA), QTypeRule(dnsdist.AXFR), QTypeRule(dnsdist.IXFR)}), PoolAction("master"))
+  addAction(OrRule({QTypeRule(DNSQType.SOA), QTypeRule(DNSQType.AXFR), QTypeRule(DNSQType.IXFR)}), PoolAction("master"))
 
 The second one might require allowing AXFR/IXFR from the :program:`dnsdist` source address
 and moving the source address check to :program:`dnsdist`'s side::
 
-  addAction(AndRule({OrRule({QTypeRule(dnsdist.AXFR), QTypeRule(dnsdist.IXFR)}), NotRule(makeRule("192.168.1.0/24"))}), RCodeAction(dnsdist.REFUSED))
+  addAction(AndRule({OrRule({QTypeRule(DNSQType.AXFR), QTypeRule(DNSQTypeIXFR)}), NotRule(makeRule("192.168.1.0/24"))}), RCodeAction(DNSRCode.REFUSED))
+
+.. versionchanged:: 1.4.0
+  Before 1.4.0, the QTypes were in the ``dnsdist`` namespace. Use ``dnsdist.AXFR`` and ``dnsdist.IXFR`` in these versions.
+  Before 1.4.0, the RCodes were in the ``dnsdist`` namespace. Use ``dnsdist.REFUSED`` in these versions.
 
 When :program:`dnsdist` is deployed in front of slaves, however, an issue might arise with NOTIFY
 queries, because the slave will receive a notification coming from the :program:`dnsdist` address,
@@ -26,5 +30,7 @@ and not the master's one. One way to fix this issue is to allow NOTIFY from the
 address on the slave side (for example with PowerDNS's `trusted-notification-proxy`) and move the address
 check to :program:`dnsdist`'s side::
 
-  addAction(AndRule({OpcodeRule(DNSOpcode.Notify), NotRule(makeRule("192.168.1.0/24"))}), RCodeAction(dnsdist.REFUSED))
+  addAction(AndRule({OpcodeRule(DNSOpcode.Notify), NotRule(makeRule("192.168.1.0/24"))}), RCodeAction(DNSRCode.REFUSED))
 
+.. versionchanged:: 1.4.0
+  Before 1.4.0, the RCodes were in the ``dnsdist`` namespace. Use ``dnsdist.REFUSED`` in these versions.
\ No newline at end of file
index 89be18efce60ba1054ec66b7b42667ba334656e8..0655f513e34433731e1f2aac9d854affc7cc12f7 100644 (file)
@@ -4,7 +4,7 @@ Lua actions in rules
 While we can pass every packet through the :func:`blockFilter` functions, it is also possible to configure :program:`dnsdist` to only hand off some packets for Lua inspection. 
 If you think Lua is too slow for your query load, or if you are doing heavy processing in Lua, this may make sense.
 
-To select specific packets for Lua attention, use :func:`addLuaAction` or :func:`addLuaResponseAction`.
+To select specific packets for Lua attention, use :func:`addAction` with :func:`LuaAction`, or :func:`addResponseAction` with :func:`LuaResponseAction`.
 
 A sample configuration could look like this::
 
@@ -17,4 +17,4 @@ A sample configuration could look like this::
     end
   end
 
-  addLuaAction(AllRule(), luarule)
+  addAction(AllRule(), LuaAction(luarule))
index 5b645405d61a9e43d89e83f15a7675bf0b27bba8..ffdbacf6667b55ff02b1b159e66f3e4eac85aab1 100644 (file)
@@ -63,4 +63,4 @@ A working example:
           end
   end
 
-  addLuaAction(AllRule(), pickPool)
+  addAction(AllRule(), LuaAction(pickPool))
index 4e9d28679775967c3f93d56596f3e0a6bfe1b33f..cf72ca7c51559bd2674a9e190075ba034d54d294 100644 (file)
@@ -15,6 +15,7 @@ First, a few words about :program:`dnsdist` architecture:
 The maximum number of threads in the TCP pool is controlled by the :func:`setMaxTCPClientThreads` directive, and defaults to 10.
 This number can be increased to handle a large number of simultaneous TCP connections.
 If all the TCP threads are busy, new TCP connections are queued while they wait to be picked up.
+Before 1.4.0, a TCP thread could only handle a single incoming connection at a time. Starting with 1.4.0 the handling of TCP connections is now event-based, so a single TCP worker can handle a large number of TCP incoming connections simultaneously.
 
 The maximum number of queued connections can be configured with :func:`setMaxTCPQueuedConnections` and defaults to 1000.
 Any value larger than 0 will cause new connections to be dropped if there are already too many queued.
index 48ec0fa6356616ac2b69e3cce58aa699cc3315b7..3fea22211767cf0d592f6ec2d3191eee424cdb67 100644 (file)
@@ -1,6 +1,270 @@
 Changelog
 =========
 
+.. changelog::
+  :version: 1.4.0-alpha1
+  :released: 12th of April 2019
+
+ .. change::
+    :tags: Improvements
+    :pullreq: 7167
+
+    Fix compiler warning about returning garbage (Adam Majer)
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7168
+
+    Fix warnings, mostly unused parameters, reported by -wextra
+
+  .. change::
+    :tags: New Features
+    :pullreq: 6959
+    :tickets: 6941, 2362
+
+    Add namespace and instance variable to carbon key (Gibheer)
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7191
+
+    Add optional uuid column to showServers()
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7087
+
+    Allow NoRecurse for use in dynamic blocks or Lua rules (phonedph1)
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7197
+    :tickets: 7194
+
+    Expose secpoll status
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7026
+
+    Configure --enable-pdns-option --with-third-party-module (Josh Soref)
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7256
+
+    Protect GnuTLS tickets key rotation with a read-write lock
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7267
+
+    Check that ``SO_ATTACH_BPF`` is defined before enabling eBPF
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7138
+
+    Drop remaining capabilities after startup
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7323
+    :tickets: 7236
+
+    Add an optional 'checkTimeout' parameter to 'newServer()'
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7322
+    :tickets: 7237
+
+    Add a 'rise' parameter to 'newServer()'
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7310
+    :tickets: 7239
+
+    Add a 'keepStaleData' option to the packet cache
+
+  .. change::
+    :tags: New Features
+    :pullreq: 6967
+    :tickets: 6846, 6897
+
+    Expose trailing data (Richard Gibson)
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 6634
+
+    More sandboxing using systemd's features
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7426
+
+    Fix off-by-one in mvRule counting
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7428
+
+    Reduce systemcall usage in Protobuf logging
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7433
+
+    Resync YaHTTP code to cmouse/yahttp@11be77a1fc4032 (Chris Hofstaedtler)
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7142
+
+    Add option to set interval between health checks (1848)
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7406
+
+    Add EDNS unknown version handling (Dmitry Alenichev)
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7431
+
+    Pass empty response (Dmitry Alenichev)
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7502
+
+    Change the way getRealMemusage() works on linux (using statm)
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7520
+
+    Don't convert nsec to usec if we need nsec
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7537
+
+    DNSNameSet and QNameSetRule (Andrey)
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7594
+
+    Fix setRules()
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7560
+
+    Handle EAGAIN in the GnuTLS DNS over TLS provider
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7586
+    :tickets: 7461
+
+    Gracefully handle a null latency in the webserver's js
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7585
+    :tickets: 7534
+
+     Prevent 0-ttl cache hits
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7343
+    :tickets: 7139
+
+    Add addDynBlockSMT() support to dynBlockRulesGroup
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7578
+
+    Add frontend response statistics (Matti Hiljanen)
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7652
+
+   EDNSOptionView improvements
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7481
+    :tickets: 6242
+
+    Add support for encrypting ip addresses #gdpr 
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7670
+
+    Remove addLuaAction and addLuaResponseAction
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7559
+    :tickets: 7526, 4814
+
+    Refactoring of the TCP stack
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7674
+    :tickets: 7481
+
+    Honor libcrypto include path
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7677
+    :tickets: 5653
+
+    Add 'setSyslogFacility()'
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7692
+    :tickets: 7556
+
+    Prevent a conflict with BADSIG being clobbered
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7689
+
+    Switch to the new 'newPacketCache()' syntax for 1.4.0
+
+  .. change::
+    :tags: New Features
+    :pullreq: 7676
+
+    Add 'reloadAllCertificates()'
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7678
+
+    Move constants to proper namespace
+
+  .. change::
+    :tags: Improvements
+    :pullreq: 7694
+
+    Unify the management of DNS/DNSCrypt/DoT frontends
+
 .. changelog::
   :version: 1.3.3
   :released: 8th of November 2018
index 1a6bdd3e3a7d7a17a8c3954a7f5de482d919f6bf..167c1503e13c42267ee8762c18740e50d3980fc2 100644 (file)
@@ -5,7 +5,7 @@ Caching Responses
 It is enabled per-pool, but the same cache can be shared between several pools.
 The first step is to define a cache with :func:`newPacketCache`, then to assign that cache to the chosen pool, the default one being represented by the empty string::
 
-  pc = newPacketCache(10000, 86400, 0, 60, 60, false)
+  pc = newPacketCache(10000, {maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60, dontAge=false})
   getPool(""):setCache(pc)
 
  + The first parameter (10000) is the maximum number of entries stored in the cache, and is the only one required. All the other parameters are optional and in seconds, except the last one which is a boolean.
@@ -49,7 +49,10 @@ For example, to remove all expired entries::
 
 Specific entries can also be removed using the :meth:`PacketCache:expungeByName` method::
 
-  getPool("poolname"):getCache():expungeByName(newDNSName("powerdns.com"), dnsdist.A)
+  getPool("poolname"):getCache():expungeByName(newDNSName("powerdns.com"), DNSQType.A)
+
+.. versionchanged:: 1.4.0
+  Before 1.4.0, the QTypes were in the ``dnsdist`` namespace. Use ``dnsdist.A`` in these versions.
 
 Finally, the :meth:`PacketCache:expunge` method will remove all entries until at most n entries remain in the cache::
 
index 191b934d6983d066c3e509230477086b3fd6ac6c..7e7a4495d27b374cd6127beb847c4d90769059ba 100644 (file)
@@ -54,6 +54,7 @@ You can also set the hash perturbation value, see :func:`setWHashedPertubation`.
 ~~~~~~~~~~~~~~
 
 The last available policy is ``roundrobin``, which indiscriminately sends each query to the next server that is up.
+If all servers are down, the policy will still select one server by default. Setting :func:`setRoundRobinFailOnNoServer` to ``true`` will change this behavior.
 
 Lua server policies
 -------------------
@@ -131,7 +132,7 @@ Functions
 
   If set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query.
 
-  :param bool value:
+  :param bool value: whether to return a servfail instead of dropping the query
 
 .. function:: setPoolServerPolicy(policy, pool)
 
@@ -148,6 +149,14 @@ Functions
   :param string function: name of the function
   :param string pool: Name of the pool
 
+.. function:: setRoundRobinFailOnNoServer(value)
+
+  .. versionadded:: 1.4.0
+
+  By default the roundrobin load-balancing policy will still try to select a backend even if all backends are currently down. Setting this to true will make the policy fail and return that no server is available instead.
+
+  :param bool value: whether to fail when all servers are down
+
 .. function:: showPoolServerPolicy(pool)
 
   Print server selection policy for ``pool``.
index b32b427795e512d2e686c98cf2bfa1ea60cabad5..7ecdfa6545f0055a71a299a6fcd0824e063e8f1f 100644 (file)
@@ -22,6 +22,18 @@ ComboAddresses can be IPv4 or IPv6, and unless you want to know, you don't need
 
     Returns the port number.
 
+  .. method:: ComboAddress:ipdecrypt(key) -> ComboAddress
+
+    Decrypt this IP address as described in https://powerdns.org/ipcipher
+
+    :param string key: A 16 byte key. Note that this can be derived from a passphrase with the standalone function `makeIPCipherKey`
+
+  .. method:: ComboAddress:ipencrypt(key) -> ComboAddress
+
+    Encrypt this IP address as described in https://powerdns.org/ipcipher
+
+    :param string key: A 16 byte key. Note that this can be derived from a passphrase with the standalone function `makeIPCipherKey`
+
   .. method:: ComboAddress:isIPv4() -> bool
 
     Returns true if the address is an IPv4, false otherwise
index 62604d792db5f1450979f662579baa0138986189..4252d890a2f2dee69d395064950435e755572999 100644 (file)
@@ -42,6 +42,20 @@ Global configuration
 
   :param str path: The directory to load configuration files from. Each file must end in ``.conf``.
 
+.. function:: reloadAllCertificates()
+
+  .. versionadded:: 1.4.0
+
+  Reload all DNSCrypt and TLS certificates, along with their associated keys.
+
+.. function:: setSyslogFacility(facility)
+
+  .. versionadded:: 1.4.0
+
+  Set the syslog logging facility to ``facility``.
+
+  :param int facility: The new facility as a numeric value. Defaults to LOG_DAEMON.
+
 Listen Sockets
 ~~~~~~~~~~~~~~
 
@@ -310,7 +324,7 @@ Servers
   .. versionchanged:: 1.3.0
     Added ``checkClass``, ``sockets`` and ``checkFunction`` to server_table.
 
-  .. versionchanged:: 1.3.4
+  .. versionchanged:: 1.4.0
     Added ``checkTimeout`` and ``rise`` to server_table.
 
   Add a new backend server. Call this function with either a string::
@@ -537,7 +551,7 @@ PacketCache
 A Pool can have a packet cache to answer queries directly in stead of going to the backend.
 See :doc:`../guides/cache` for a how to.
 
-.. function:: newPacketCache(maxEntries[, maxTTL=86400[, minTTL=0[, temporaryFailureTTL=60[, staleTTL=60[, dontAge=false[, numberOfShards=1[, deferrableInsertLock=true[, maxNegativeTTL=3600[, parseECS=false [,options]]]]]]]]) -> PacketCache
+.. function:: newPacketCache(maxEntries[, maxTTL=86400[, minTTL=0[, temporaryFailureTTL=60[, staleTTL=60[, dontAge=false[, numberOfShards=1[, deferrableInsertLock=true[, maxNegativeTTL=3600[, parseECS=false]]]]]]]) -> PacketCache
 
   .. versionchanged:: 1.3.0
     ``numberOfShards`` and ``deferrableInsertLock`` parameters added.
@@ -545,11 +559,9 @@ See :doc:`../guides/cache` for a how to.
   .. versionchanged:: 1.3.1
     ``maxNegativeTTL`` and ``parseECS`` parameters added.
 
-  .. versionchanged:: 1.3.4
-    ``options`` parameter added.
+  .. deprecated:: 1.4.0
 
   Creates a new :class:`PacketCache` with the settings specified.
-  Starting with 1.3.4, all parameters can be specified in the ``options`` table, overriding the value from the existing parameters if any.
 
   :param int maxEntries: The maximum number of entries in this cache
   :param int maxTTL: Cap the TTL for records to his number
@@ -561,7 +573,14 @@ See :doc:`../guides/cache` for a how to.
   :param bool deferrableInsertLock: Whether the cache should give up insertion if the lock is held by another thread, or simply wait to get the lock
   :param int maxNegativeTTL: Cache a NXDomain or NoData answer from the backend for at most this amount of seconds, even if the TTL of the SOA record is higher
   :param bool parseECS: Whether any EDNS Client Subnet option present in the query should be extracted and stored to be able to detect hash collisions involving queries with the same qname, qtype and qclass but a different incoming ECS value. Enabling this option adds a parsing cost and only makes sense if at least one backend might send different responses based on the ECS value, so it's disabled by default
-  :param table options: A table with key: value pairs with the options listed below:
+
+.. function:: newPacketCache(maxEntries, [options]) -> PacketCache
+
+  .. versionadded:: 1.4.0
+
+  Creates a new :class:`PacketCache` with the settings specified.
+
+  :param int maxEntries: The maximum number of entries in this cache
 
   Options:
 
@@ -595,7 +614,7 @@ See :doc:`../guides/cache` for a how to.
 
     :param int n: Number of entries to keep
 
-  .. method:: PacketCache:expungeByName(name [, qtype=dnsdist.ANY[, suffixMatch=false]])
+  .. method:: PacketCache:expungeByName(name [, qtype=DNSQType.ANY[, suffixMatch=false]])
 
     .. versionchanged:: 1.2.0
       ``suffixMatch`` parameter added.
@@ -603,12 +622,12 @@ See :doc:`../guides/cache` for a how to.
     Remove entries matching ``name`` and type from the cache.
 
     :param DNSName name: The name to expunge
-    :param int qtype: The type to expunge
+    :param int qtype: The type to expunge, can be a pre-defined :ref:`DNSQType`
     :param bool suffixMatch: When set to true, remove al entries under ``name``
 
   .. method:: PacketCache:getStats()
 
-    .. versionadded:: 1.3.4
+    .. versionadded:: 1.4.0
 
     Return the cache stats (number of entries, hits, misses, deferred lookups, deferred inserts, lookup collisions, insert collisions and TTL too shorts) as a Lua table.
 
@@ -718,7 +737,7 @@ Status, Statistics and More
 
 .. function:: showServers([options])
 
-  .. versionchanged:: 1.3.4
+  .. versionchanged:: 1.4.0
     ``options`` optional parameter added
 
   This function shows all backend servers currently configured and some statistics.
index 6ee99ef8e5480cdc2e12485045f2858fffb72b5a..03793cfa2cb6a59eea53c15d92979a074295f1f2 100644 (file)
@@ -8,6 +8,8 @@ There are many constants in :program:`dnsdist`.
 OPCode
 ------
 
+These constants represent the `OpCode <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5>`__ of a query.
+
 - ``DNSOpcode.Query``
 - ``DNSOpcode.IQuery``
 - ``DNSOpcode.Status``
@@ -16,15 +18,17 @@ OPCode
 
 Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5
 
-.. _DNSQClass:
+.. _DNSClass:
 
-QClass
+DNSClass
 ------
 
-- ``QClass.IN``
-- ``QClass.CHAOS``
-- ``QClass.NONE``
-- ``QClass.ANY``
+These constants represent the `CLASS <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-2>`__ of a DNS record.
+
+- ``DNSClass.IN``
+- ``DNSClass.CHAOS``
+- ``DNSClass.NONE``
+- ``DNSClass.ANY``
 
 Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-2
 
@@ -33,31 +37,34 @@ Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#
 RCode
 -----
 
-- ``dnsdist.NOERROR``
-- ``dnsdist.FORMERR``
-- ``dnsdist.SERVFAIL``
-- ``dnsdist.NXDOMAIN``
-- ``dnsdist.NOTIMP``
-- ``dnsdist.REFUSED``
-- ``dnsdist.YXDOMAIN``
-- ``dnsdist.YXRRSET``
-- ``dnsdist.NXRRSET``
-- ``dnsdist.NOTAUTH``
-- ``dnsdist.NOTZONE``
-- ``dnsdist.BADVERS``
-- ``dnsdist.BADSIG``
-- ``dnsdist.BADKEY``
-- ``dnsdist.BADTIME``
-- ``dnsdist.BADMODE``
-- ``dnsdist.BADNAME``
-- ``dnsdist.BADALG``
-- ``dnsdist.BADTRUNC``
-- ``dnsdist.BADCOOKIE``
-
-RCodes below and including ``BADVERS`` are extended RCodes that can only be matched using :func:`ERCodeRule`.
-
-Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6
-
+These constants represent the different `RCODEs <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6>`__ for DNS messages.
+
+.. versionchanged:: 1.4.0
+  The prefix is changed from ``dnsdist`` to ``DNSRCode``.
+
+- ``DNSRCode.NOERROR``
+- ``DNSRCode.FORMERR``
+- ``DNSRCode.SERVFAIL``
+- ``DNSRCode.NXDOMAIN``
+- ``DNSRCode.NOTIMP``
+- ``DNSRCode.REFUSED``
+- ``DNSRCode.YXDOMAIN``
+- ``DNSRCode.YXRRSET``
+- ``DNSRCode.NXRRSET``
+- ``DNSRCode.NOTAUTH``
+- ``DNSRCode.NOTZONE``
+
+RCodes below are extended RCodes that can only be matched using :func:`ERCodeRule`.
+
+- ``DNSRCode.BADVERS``
+- ``DNSRCode.BADSIG``
+- ``DNSRCode.BADKEY``
+- ``DNSRCode.BADTIME``
+- ``DNSRCode.BADMODE``
+- ``DNSRCode.BADNAME``
+- ``DNSRCode.BADALG``
+- ``DNSRCode.BADTRUNC``
+- ``DNSRCode.BADCOOKIE``
 
 .. _EDNSOptionCode:
 
@@ -80,8 +87,10 @@ Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#
 
 .. _DNSSection:
 
-DNS Section
------------
+DNS Packet Sections
+-------------------
+
+These constants represent the section in the DNS Packet.
 
 - ``DNSSection.Question``
 - ``DNSSection.Answer``
@@ -93,7 +102,7 @@ DNS Section
 DNSAction
 ---------
 
-These constants represent an Action that can be returned from the functions invoked by :func:`addLuaAction`.
+These constants represent an Action that can be returned from :func:`LuaAction` functions.
 
  * ``DNSAction.Allow``: let the query pass, skipping other rules
  * ``DNSAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
@@ -109,13 +118,29 @@ These constants represent an Action that can be returned from the functions invo
  * ``DNSAction.Truncate``: truncate the response
  * ``DNSAction.NoRecurse``: set rd=0 on the query
 
+.. _DNSQType:
+
+DNSQType
+--------
+
+.. versionchanged:: 1.3.0
+  The prefix is changed from ``dnsdist.`` to ``DNSQType``.
+
+All named `QTypes <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4>`__ are available as constants, prefixed with ``DNSQType.``, e.g.:
+
+ * ``DNSQType.AAAA``
+ * ``DNSQType.AXFR``
+ * ``DNSQType.A``
+ * ``DNSQType.NS``
+ * ``DNSQType.SOA``
+ * etc.
 
 .. _DNSResponseAction:
 
 DNSResponseAction
 -----------------
 
-These constants represent an Action that can be returned from the functions invoked by :func:`addLuaResponseAction`.
+These constants represent an Action that can be returned from :func:`LuaResponseAction` functions.
 
  * ``DNSResponseAction.Allow``: let the response pass, skipping other rules
  * ``DNSResponseAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
index c0fc2702d32593d99b83851728d818a23f2671c1..e451eb62e06616619ff862d6cd75970def55b775 100644 (file)
@@ -47,7 +47,7 @@ This state can be modified from the various hooks.
   .. attribute:: DNSQuestion.qtype
 
     QType (as an unsigned integer) of this question.
-    Can be compared against ``dnsdist.A``, ``dnsdist.AAAA`` etc.
+    Can be compared against the pre-defined :ref:`constants <DNSQType>` like ``DNSQType.A``, DNSQType.AAAA``.
 
   .. attribute:: DNSQuestion.remoteaddr
 
@@ -170,7 +170,7 @@ DNSResponse object
 
     - ``section`` is the section in the packet and can be compared to :ref:`DNSSection`
     - ``qclass`` is the QClass of the record. Can be compared to :ref:`DNSQClass`
-    - ``qtype`` is the QType of the record. Can be e.g. compared to ``dnsdist.A``, ``dnsdist.AAAA`` and the like.
+    - ``qtype`` is the QType of the record. Can be e.g. compared to ``DNSQType.A``, ``DNSQType.AAAA`` :ref:`constants <DNSQType>` and the like.
     - ``ttl`` is the current TTL
 
     This function must return an integer with the new TTL.
@@ -231,7 +231,7 @@ EDNSOptionView object
 
   An object that represents the values of a single EDNS option received in a query.
 
-  .. attribute:: EDNSOptionView.count -> int
+  .. method:: EDNSOptionView:count()
 
     The number of values for this EDNS option.
 
index 54d56823627cfeeaabaf7c629fe5578b3cffe786..0f63a5145446e0000b1a4a8248a5c6bc87617566 100644 (file)
@@ -3,7 +3,7 @@ Tuning related functions
 
 .. function:: setMaxTCPClientThreads(num)
 
-  Set the maximum of TCP client threads, handling TCP connections
+  Set the maximum of TCP client threads, handling TCP connections. Before 1.4.0 a TCP thread could only handle a single incoming TCP connection at a time, while after 1.4.0 it can handle a larger number of them simultaneously.
 
   :param int num:
 
index 36d8b3a55b5702f60250d3f157c1136dcd9e2c42..f2ad20a6d30e33652c1b39ed4f4d478f73a4a498 100644 (file)
@@ -81,7 +81,10 @@ Rule Generators
   Set the TC-bit (truncate) on ANY queries received over UDP, forcing a retry over TCP.
   This function is deprecated as of 1.2.0 and will be removed in 1.3.0. This is equivalent to doing::
 
-    addAction(AndRule({QTypeRule(dnsdist.ANY), TCPRule(false)}), TCAction())
+   addAction(AndRule({QTypeRule(DNSQType.ANY), TCPRule(false)}), TCAction())
+
+  .. versionchanged:: 1.4.0
+    Before 1.4.0, the QTypes were in the ``dnsdist`` namespace. Use ``dnsdist.ANY`` in these versions.
 
 .. function:: addDelay(DNSrule, delay)
 
@@ -151,6 +154,9 @@ Rule Generators
   .. versionchanged:: 1.3.0
     The second argument returned by the ``function`` can be omitted. For earlier releases, simply return an empty string.
 
+  .. deprecated:: 1.4.0
+    Removed in 1.4.0, use :func:`LuaAction` with :func:`addAction` instead.
+
   Invoke a Lua function that accepts a :class:`DNSQuestion`.
   This function works similar to using :func:`LuaAction`.
   The ``function`` should return both a :ref:`DNSAction` and its argument `rule`. The `rule` is used as an argument
@@ -187,6 +193,9 @@ Rule Generators
   .. versionchanged:: 1.3.0
     The second argument returned by the ``function`` can be omitted. For earlier releases, simply return an empty string.
 
+  .. deprecated:: 1.4.0
+    Removed in 1.4.0, use :func:`LuaResponseAction` with :func:`addResponseAction` instead.
+
   Invoke a Lua function that accepts a :class:`DNSResponse`.
   This function works similar to using :func:`LuaResponseAction`.
   The ``function`` should return both a :ref:`DNSResponseAction` and its argument `rule`. The `rule` is used as an argument
@@ -712,7 +721,7 @@ These ``DNSRule``\ s be one of the following items:
 
   Matches if there is at least ``minCount`` and at most ``maxCount`` records of type ``type`` in the section ``section``.
   ``section`` can be specified as an integer or as a ref:`DNSSection`.
-  ``qtype`` may be specified as an integer or as one of the built-in QTypes, for instance ``dnsdist.A`` or ``dnsdist.TXT``.
+  ``qtype`` may be specified as an integer or as one of the :ref:`built-in QTypes <DNSQType>`, for instance ``DNSQType.A`` or ``DNSQType.TXT``.
 
   :param int section: The section to match on
   :param int qtype: The QTYPE to match on
@@ -963,6 +972,9 @@ The following actions exist.
   .. versionchanged:: 1.3.0
     ``options`` optional parameter added.
 
+  .. versionchanged:: 1.4.0
+    ``ipEncryptKey`` optional key added to the options table.
+
   Send the content of this query to a remote logger via Protocol Buffer.
   ``alterFunction`` is a callback, receiving a :class:`DNSQuestion` and a :class:`DNSDistProtoBufMessage`, that can be used to modify the Protocol Buffer content, for example for anonymization purposes
 
@@ -973,12 +985,16 @@ The following actions exist.
   Options:
 
   * ``serverID=""``: str - Set the Server Identity field.
+  * ``ipEncryptKey=""``: str - A key, that can be generated via the :ref:`makeIPCipherKey` function, to encrypt the IP address of the requestor for anonymization purposes. The encryption is done using ipcrypt for IPv4 and a 128-bit AES ECB operation for IPv6.
 
 .. function:: RemoteLogResponseAction(remoteLogger[, alterFunction[, includeCNAME [, options]]])
 
   .. versionchanged:: 1.3.0
     ``options`` optional parameter added.
 
+  .. versionchanged:: 1.4.0
+    ``ipEncryptKey`` optional key added to the options table.
+
   Send the content of this response to a remote logger via Protocol Buffer.
   ``alterFunction`` is the same callback that receiving a :class:`DNSQuestion` and a :class:`DNSDistProtoBufMessage`, that can be used to modify the Protocol Buffer content, for example for anonymization purposes
   ``includeCNAME`` indicates whether CNAME records inside the response should be parsed and exported.
@@ -992,6 +1008,7 @@ The following actions exist.
   Options:
 
   * ``serverID=""``: str - Set the Server Identity field.
+  * ``ipEncryptKey=""``: str - A key, that can be generated via the :ref:`makeIPCipherKey` function, to encrypt the IP address of the requestor for anonymization purposes. The encryption is done using ipcrypt for IPv4 and a 128-bit AES ECB operation for IPv6.
 
 .. function:: SetECSAction(v4 [, v6])
 
index a56e9e0fa8d83e663c0dfb6009afa4b92be9d9d9..10f7de4ff1ba49c53c23e7887f9bf64228a9cdc5 100644 (file)
@@ -67,6 +67,18 @@ fd-usage
 --------
 Number of currently used file descriptors.
 
+frontend-noerror
+----------------
+Number of NoError answers sent to clients.
+
+frontend-nxdomain
+-----------------
+Number of NXDomain answers sent to clients.
+
+frontend-servfail
+-----------------
+Number of ServFail answers sent to clients.
+
 latency-avg100
 --------------
 Average response latency in microseconds of the last 100 packets
index 920df10922323ce03e1084b5cca271bb09ca6954..4de4298cfb40aa3e2df15dc65ad0cf47edd1c259 100644 (file)
@@ -1,6 +1,25 @@
 Upgrade Guide
 =============
 
+1.3.x to 1.4.0
+--------------
+
+:func:`addLuaAction` and :func:`addLuaResponseAction` have been removed. Instead, use :func:`addAction` with a :func:`LuaAction`, or :func:`addResponseAction` with a :func:`LuaResponseAction`.
+
+:func:`newPacketCache` now takes an optional table as its second argument, instead of several optional parameters.
+
+Lua's constants for DNS response codes and QTypes have been moved from the 'dnsdist' prefix to, respectively, the 'DNSQType' and 'DNSRCode' prefix.
+
+To improve security, all ambient capabilities are now dropped after the startup phase, which might prevent launching the webserver on a privileged port at run-time, or impact some custom Lua code. In addition, systemd's sandboxing features are now determined at compile-time, resulting in more restrictions on recent distributions. See pull requests 7138 and 6634 for more information.
+
+If you are compiling dnsdist, note that several ./configure options have been renamed to provide a more consistent experience. Features that depend on an external component have been prefixed with '--with-' while internal features use '--enable-'. This lead to the following changes:
+
+- ``--enable-fstrm`` to ``--enable-dnstap``
+- ``--enable-gnutls`` to ``--with-gnutls``
+- ``--enable-libsodium`` to ``--with-libsodium``
+- ``--enable-libssl`` to ``--with-libssl``
+- ``--enable-re2`` to ``--with-re2``
+
 1.3.2 to 1.3.3
 --------------
 
diff --git a/pdns/dnsdistdist/ext/ipcrypt/LICENSE b/pdns/dnsdistdist/ext/ipcrypt/LICENSE
new file mode 120000 (symlink)
index 0000000..2e88816
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/LICENSE
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/Makefile.am b/pdns/dnsdistdist/ext/ipcrypt/Makefile.am
new file mode 120000 (symlink)
index 0000000..8111f1d
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/Makefile.am
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c b/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c
new file mode 120000 (symlink)
index 0000000..e7d1241
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/ipcrypt.c
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h b/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h
new file mode 120000 (symlink)
index 0000000..ad5c2b8
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/ipcrypt.h
\ No newline at end of file
index 641a2797b4a72d8169be528b9973bf09475ca0e7..49304f125c00bb40a93ee8bbf00e9a8b73f2502b 100644 (file)
@@ -194,7 +194,8 @@ $(document).ready(function() {
                      var bouw='<table width="100%"><tr align=right><th>#</th><th align=left>Name</th><th align=left>Address</th><th>Status</th><th>Latency</th><th>Queries</th><th>Drops</th><th>QPS</th><th>Out</th><th>Weight</th><th>Order</th><th align=left>Pools</th></tr>';
                      $.each(data["servers"], function(a,b) {
                          bouw = bouw + ("<tr align=right><td>"+b["id"]+"</td><td align=left>"+b["name"]+"</td><td align=left>"+b["address"]+"</td><td>"+b["state"]+"</td>");
-                         bouw = bouw + ("<td>"+(b["latency"]).toFixed(2)+"</td><td>"+b["queries"]+"</td><td>"+b["reuseds"]+"</td><td>"+(b["qps"]).toFixed(2)+"</td><td>"+b["outstanding"]+"</td>");
+                         var latency = (b["latency"] === null) ? 0.0 : b["latency"];
+                         bouw = bouw + ("<td>"+latency.toFixed(2)+"</td><td>"+b["queries"]+"</td><td>"+b["reuseds"]+"</td><td>"+(b["qps"]).toFixed(2)+"</td><td>"+b["outstanding"]+"</td>");
                          bouw = bouw + ("<td>"+b["weight"]+"</td><td>"+b["order"]+"</td><td align=left>"+b["pools"]+"</td></tr>");
                      }); 
                      bouw = bouw + "</table>";
diff --git a/pdns/dnsdistdist/ipcipher.cc b/pdns/dnsdistdist/ipcipher.cc
new file mode 120000 (symlink)
index 0000000..794b273
--- /dev/null
@@ -0,0 +1 @@
+../ipcipher.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ipcipher.hh b/pdns/dnsdistdist/ipcipher.hh
new file mode 120000 (symlink)
index 0000000..e8d5917
--- /dev/null
@@ -0,0 +1 @@
+../ipcipher.hh
\ No newline at end of file
index 2be4a4c62fe4675b13468dc0271b5f696c69f9bb..32731c85b5a5c21dc174c819f1e0f5cca34e899d 100644 (file)
@@ -232,7 +232,7 @@ private:
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
-  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free))
+  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -247,12 +247,62 @@ public:
     if (!SSL_set_fd(d_conn.get(), d_socket)) {
       throw std::runtime_error("Error assigning socket");
     }
+  }
+
+  IOState convertIORequestToIOState(int res) const
+  {
+    int error = SSL_get_error(d_conn.get(), res);
+    if (error == SSL_ERROR_WANT_READ) {
+      return IOState::NeedRead;
+    }
+    else if (error == SSL_ERROR_WANT_WRITE) {
+      return IOState::NeedWrite;
+    }
+    else if (error == SSL_ERROR_SYSCALL) {
+      throw std::runtime_error("Error while processing TLS connection:" + std::string(strerror(errno)));
+    }
+    else {
+      throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error));
+    }
+  }
+
+  void handleIORequest(int res, unsigned int timeout)
+  {
+    auto state = convertIORequestToIOState(res);
+    if (state == IOState::NeedRead) {
+      res = waitForData(d_socket, timeout);
+      if (res <= 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+    }
+    else if (state == IOState::NeedWrite) {
+      res = waitForRWData(d_socket, false, timeout, 0);
+      if (res <= 0) {
+        throw std::runtime_error("Error waiting to write to TLS connection");
+      }
+    }
+  }
+
+  IOState tryHandshake() override
+  {
+    int res = SSL_accept(d_conn.get());
+    if (res == 1) {
+      return IOState::Done;
+    }
+    else if (res < 0) {
+      return convertIORequestToIOState(res);
+    }
 
+    throw std::runtime_error("Error accepting TLS connection");
+  }
+
+  void doHandshake() override
+  {
     int res = 0;
     do {
       res = SSL_accept(d_conn.get());
       if (res < 0) {
-        handleIORequest(res, timeout);
+        handleIORequest(res, d_timeout);
       }
     }
     while (res < 0);
@@ -262,24 +312,40 @@ public:
     }
   }
 
-  void handleIORequest(int res, unsigned int timeout)
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
   {
-    int error = SSL_get_error(d_conn.get(), res);
-    if (error == SSL_ERROR_WANT_READ) {
-      res = waitForData(d_socket, timeout);
-      if (res <= 0) {
-        throw std::runtime_error("Error reading from TLS connection");
+    do {
+      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
       }
-    }
-    else if (error == SSL_ERROR_WANT_WRITE) {
-      res = waitForRWData(d_socket, false, timeout, 0);
-      if (res <= 0) {
-        throw std::runtime_error("Error waiting to write to TLS connection");
+      else if (res < 0) {
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
       }
     }
-    else {
-      throw std::runtime_error("Error writing to TLS connection");
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res < 0) {
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
+      }
     }
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -300,7 +366,7 @@ public:
         handleIORequest(res, readTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
 
       if (totalTimeout) {
@@ -330,7 +396,7 @@ public:
         handleIORequest(res, writeTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
     }
     while (got < bufferSize);
@@ -346,6 +412,7 @@ public:
 
 private:
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+  unsigned int d_timeout;
 };
 
 class OpenSSLTLSIOCtx: public TLSCtx
@@ -650,7 +717,7 @@ public:
 
   GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
   {
-    unsigned int sslOptions = GNUTLS_SERVER;
+    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
     sslOptions |= GNUTLS_NO_SIGNAL;
 #endif
@@ -685,12 +752,86 @@ public:
     /* timeouts are in milliseconds */
     gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
+  }
 
+  void doHandshake() override
+  {
     int ret = 0;
     do {
       ret = gnutls_handshake(d_conn.get());
+      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    }
+    while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+  }
+
+  IOState tryHandshake() override
+  {
+    int ret = 0;
+
+    do {
+      ret = gnutls_handshake(d_conn.get());
+      if (ret == GNUTLS_E_SUCCESS) {
+        return IOState::Done;
+      }
+      else if (ret == GNUTLS_E_AGAIN) {
+        return IOState::NeedRead;
+      }
+      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    } while (ret == GNUTLS_E_INTERRUPTED);
+
+    throw std::runtime_error("Error accepting a new connection");
+  }
+
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
+  {
+    do {
+      ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
+      if (res == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
+      }
+      else if (res > 0) {
+        pos += static_cast<size_t>(res);
+      }
+      else if (res < 0) {
+        if (gnutls_error_is_fatal(res)) {
+          throw std::runtime_error("Error writing to TLS connection");
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          return IOState::NeedWrite;
+        }
+        warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+      }
+    }
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res > 0) {
+        pos += static_cast<size_t>(res);
+      }
+      else if (res < 0) {
+        if (gnutls_error_is_fatal(res)) {
+          throw std::runtime_error("Error reading from TLS connection");
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          return IOState::NeedRead;
+        }
+        warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+      }
     }
-    while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -708,7 +849,7 @@ public:
         throw std::runtime_error("Error reading from TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
@@ -750,7 +891,7 @@ public:
         throw std::runtime_error("Error writing to TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
@@ -817,7 +958,7 @@ public:
 
     rc = gnutls_priority_init(&d_priorityCache, fe.d_ciphers.empty() ? "NORMAL" : fe.d_ciphers.c_str(), nullptr);
     if (rc != GNUTLS_E_SUCCESS) {
-      warnlog("Error setting up TLS cipher preferences to %s (%s), skipping.", fe.d_ciphers.c_str(), gnutls_strerror(rc));
+      throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
     }
 
     pthread_rwlock_init(&d_lock, nullptr);
index 560c222b2958fb77203f4885a3f950bd95640bce..b191a268a5e094b35867ca456727f16ab374cba7 100644 (file)
@@ -10,6 +10,7 @@
 
 Rings g_rings;
 GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
 
 BOOST_AUTO_TEST_SUITE(dnsdistdynblocks_hh)
 
index cd2a22158d5044d7e706d1240bd62fc3717d8231..357c902db74837e5f9313f997f965fd2ed6ab425 100644 (file)
@@ -22,7 +22,8 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   uint16_t qclass = QClass::IN;
   ComboAddress lc("127.0.0.1:53");
   ComboAddress rem("192.0.2.1:42");
-  struct dnsheader* dh = nullptr;
+  struct dnsheader dh;
+  memset(&dh, 0, sizeof(dh));
   size_t bufferSize = 0;
   size_t queryLen = 0;
   bool isTcp = false;
@@ -32,7 +33,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   /* the internal QPS limiter does not use the real time */
   gettime(&expiredTime);
 
-  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, bufferSize, queryLen, isTcp, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime);
 
   for (size_t idx = 0; idx < maxQPS; idx++) {
     /* let's use different source ports, it shouldn't matter */
diff --git a/pdns/dnsdistdist/test-mplexer.cc b/pdns/dnsdistdist/test-mplexer.cc
new file mode 120000 (symlink)
index 0000000..f406267
--- /dev/null
@@ -0,0 +1 @@
+../test-mplexer.cc
\ No newline at end of file
index bd8e23dad5ecce0e5ae2a355f305249e54079acc..3ae3ea7ee3886043f7fdf6c37509ea630cd09bb7 100644 (file)
@@ -40,19 +40,40 @@ otherwise, obfuscate the response IP address
 #include "statbag.hh"
 #include "dnspcap.hh"
 #include "iputils.hh"
-
+#include "ipcipher.hh"
 #include "namespaces.hh"
+#include <boost/program_options.hpp>
+#include "base64.hh"
 
 StatBag S;
 
+namespace po = boost::program_options;
+po::variables_map g_vm;
+
+
 class IPObfuscator
 {
 public:
-  IPObfuscator() : d_romap(d_ipmap), d_ro6map(d_ip6map), d_counter(0)
+  virtual uint32_t obf4(uint32_t orig)=0;
+  virtual struct in6_addr obf6(const struct in6_addr& orig)=0;
+};
+
+class IPSeqObfuscator : public IPObfuscator
+{
+public:
+  IPSeqObfuscator() : d_romap(d_ipmap), d_ro6map(d_ip6map), d_counter(0)
   {
   }
 
-  uint32_t obf4(uint32_t orig)
+  ~IPSeqObfuscator()
+  {}
+
+  static std::unique_ptr<IPObfuscator> make()
+  {
+    return std::unique_ptr<IPObfuscator>(new IPSeqObfuscator());
+  }
+
+  uint32_t obf4(uint32_t orig) override
   {
     if(d_romap.count(orig))
       return d_ipmap[orig];
@@ -61,7 +82,7 @@ public:
     }
   }
 
-  struct in6_addr obf6(const struct in6_addr& orig)
+  struct in6_addr obf6(const struct in6_addr& orig) override
   {
     uint32_t val;
     if(d_ro6map.count(orig))
@@ -95,6 +116,48 @@ private:
   uint32_t d_counter;
 };
 
+class IPCipherObfuscator : public IPObfuscator
+{
+public:
+  IPCipherObfuscator(const std::string& key, bool decrypt)  : d_key(key), d_decrypt(decrypt)
+  {
+    if(d_key.size()!=16) {
+      throw std::runtime_error("IPCipher requires a 128 bit key");
+    }
+  }
+
+  ~IPCipherObfuscator()
+  {}
+  static std::unique_ptr<IPObfuscator> make(std::string key, bool decrypt)
+  {
+    return std::unique_ptr<IPObfuscator>(new IPCipherObfuscator(key, decrypt));
+  }
+
+  uint32_t obf4(uint32_t orig) override
+  {
+    ComboAddress ca;
+    ca.sin4.sin_family = AF_INET;
+    ca.sin4.sin_addr.s_addr = orig;
+    ca = d_decrypt ? decryptCA(ca, d_key) : encryptCA(ca, d_key);
+    return ca.sin4.sin_addr.s_addr;
+
+  }
+
+  struct in6_addr obf6(const struct in6_addr& orig) override
+  {
+    ComboAddress ca;
+    ca.sin4.sin_family = AF_INET6;
+    ca.sin6.sin6_addr = orig;
+    ca = d_decrypt ? decryptCA(ca, d_key) : encryptCA(ca, d_key);
+    return ca.sin6.sin6_addr;
+  }
+
+private:
+  std::string d_key;
+  bool d_decrypt;
+};
+
+
 void usage() {
   cerr<<"Syntax: dnswasher INFILE1 [INFILE2..] OUTFILE"<<endl;
 }
@@ -102,53 +165,98 @@ void usage() {
 int main(int argc, char** argv)
 try
 {
-  for (int i = 1; i < argc; i++) {
-    if ((string) argv[i] == "--help") {
-      usage();
-      exit(EXIT_SUCCESS);
-    }
+  po::options_description desc("Allowed options");
+  desc.add_options()
+    ("help,h", "produce help message")
+    ("version", "show version number")
+    ("key,k", po::value<string>(), "base64 encoded 128 bit key for ipcipher")
+    ("passphrase,p", po::value<string>(), "passphrase for ipcipher (will be used to derive key)")
+    ("decrypt,d", "decrypt IP addresses with ipcipher");
 
-    if ((string) argv[i] == "--version") {
-      cerr<<"dnswasher "<<VERSION<<endl;
-      exit(EXIT_SUCCESS);
-    }
+  po::options_description alloptions;
+  po::options_description hidden("hidden options");
+  hidden.add_options()
+    ("infiles", po::value<vector<string>>(), "PCAP source file(s)")
+    ("outfile", po::value<string>(), "outfile");
+
+
+  alloptions.add(desc).add(hidden);
+  po::positional_options_description p;
+  p.add("infiles", 1);
+  p.add("outfile", 1);
+
+  po::store(po::command_line_parser(argc, argv).options(alloptions).positional(p).run(), g_vm);
+  po::notify(g_vm);
+
+  if(g_vm.count("help")) {
+    usage();
+    cout<<desc<<endl;
+    exit(EXIT_SUCCESS);
+  }
+
+  if(g_vm.count("version")) {
+    cout<<"dnswasher "<<VERSION<<endl;
+    exit(EXIT_SUCCESS);
   }
 
-  if(argc < 3) {
+  if(!g_vm.count("outfile")) {
+    cout<<"Missing outfile"<<endl;
     usage();
-    exit(1);
+    exit(EXIT_FAILURE);
   }
 
-  PcapPacketWriter pw(argv[argc-1]);
-  IPObfuscator ipo;
-  // 0          1   2   3    - argc == 4
-  // dnswasher in1 in2 out
-  for(int n=1; n < argc -1; ++n) {
-    PcapPacketReader pr(argv[n]);
+  bool doDecrypt = g_vm.count("decrypt");
+
+  PcapPacketWriter pw(g_vm["outfile"].as<string>());
+  std::unique_ptr<IPObfuscator> ipo;
+
+  if(!g_vm.count("key") && !g_vm.count("passphrase"))
+    ipo = IPSeqObfuscator::make();
+  else if(g_vm.count("key") && !g_vm.count("passphrase")) {
+    string key;
+    if(B64Decode(g_vm["key"].as<string>(), key) < 0) {
+      cerr<<"Invalidly encoded base64 key provided"<<endl;
+      exit(EXIT_FAILURE);
+    }
+    ipo = IPCipherObfuscator::make(key, doDecrypt);
+  }
+  else if(!g_vm.count("key") && g_vm.count("passphrase")) {
+    string key = makeIPCipherKey(g_vm["passphrase"].as<string>());
+
+    ipo = IPCipherObfuscator::make(key, doDecrypt);
+  }
+  else {
+    cerr<<"Can't specify both 'key' and 'passphrase'"<<endl;
+    exit(EXIT_FAILURE);
+  }
+
+  for(const auto& inf : g_vm["infiles"].as<vector<string>>()) {
+    PcapPacketReader pr(inf);
     pw.setPPR(pr);
 
     while(pr.getUDPPacket()) {
       if(ntohs(pr.d_udp->uh_dport)==53 || (ntohs(pr.d_udp->uh_sport)==53 && pr.d_len > sizeof(dnsheader))) {
         dnsheader* dh=(dnsheader*)pr.d_payload;
-        
+
         if (pr.d_ip->ip_v == 4){
           uint32_t *src=(uint32_t*)&pr.d_ip->ip_src;
           uint32_t *dst=(uint32_t*)&pr.d_ip->ip_dst;
-          
+
           if(dh->qr)
-            *dst=htonl(ipo.obf4(*dst));
+            *dst=ipo->obf4(*dst);
           else
-            *src=htonl(ipo.obf4(*src));
-          
+            *src=ipo->obf4(*src);
+
           pr.d_ip->ip_sum=0;
         } else if (pr.d_ip->ip_v == 6) {
           auto src=&pr.d_ip6->ip6_src;
           auto dst=&pr.d_ip6->ip6_dst;
-          
+
           if(dh->qr)
-            *dst=ipo.obf6(*dst);
+            *dst=ipo->obf6(*dst);
           else
-            *src=ipo.obf6(*src);
+            *src=ipo->obf6(*src);
+          // IPv6 checksum does not cover source/destination addresses
         }
         pw.write();
       }
index 0d5bffe0d93040f5be3cf539bfcefe592f40a984..cce1b74769347ac2d40cfb7d73dec83e69ba29a4 100644 (file)
@@ -70,6 +70,13 @@ void dolog(std::ostream& os, const char* s, T value, Args... args)
 extern bool g_verbose;
 extern bool g_syslog;
 
+inline void setSyslogFacility(int facility)
+{
+  /* we always call openlog() right away at startup */
+  closelog();
+  openlog("dnsdist", LOG_PID|LOG_NDELAY, facility);
+}
+
 template<typename... Args>
 void genlog(int level, const char* s, Args... args)
 {
index 97d7e82ff6a887a0799e29e6a5873fa4614bead7..433687d21456f8c24e7227683bb39bcdf719a04d 100644 (file)
@@ -45,7 +45,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -94,9 +94,9 @@ EpollFDMultiplexer::EpollFDMultiplexer() : d_eevents(new epoll_event[s_maxevents
     
 }
 
-void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct epoll_event eevent;
   
@@ -156,13 +156,13 @@ int EpollFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(d_eevents[n].data.fd);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't refind ourselves as writable!
     }
     d_iter=d_writeCallbacks.find(d_eevents[n].data.fd);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
   d_inrun=false;
diff --git a/pdns/ipcipher.cc b/pdns/ipcipher.cc
new file mode 100644 (file)
index 0000000..57a2aa3
--- /dev/null
@@ -0,0 +1,99 @@
+#include "ipcipher.hh"
+#include "ext/ipcrypt/ipcrypt.h"
+#include <openssl/aes.h>
+#include <openssl/evp.h>
+
+/*
+int PKCS5_PBKDF2_HMAC_SHA1(const char *pass, int passlen,
+                           const unsigned char *salt, int saltlen, int iter,
+                           int keylen, unsigned char *out);
+*/
+std::string makeIPCipherKey(const std::string& password)
+{
+  static const char salt[]="ipcipheripcipher";
+  unsigned char out[16];
+
+  PKCS5_PBKDF2_HMAC_SHA1(password.c_str(), password.size(), (const unsigned char*)salt, sizeof(salt)-1, 50000, sizeof(out), out);
+
+  return std::string((const char*)out, (const char*)out + sizeof(out));
+}
+
+static ComboAddress encryptCA4(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  // always returns 0, has no failure mode
+  ipcrypt_encrypt(      (unsigned char*)&ret.sin4.sin_addr.s_addr,
+                 (const unsigned char*)  &ca.sin4.sin_addr.s_addr,
+                 (const unsigned char*)key.c_str());
+  return ret;
+}
+
+static ComboAddress decryptCA4(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  // always returns 0, has no failure mode
+  ipcrypt_decrypt(      (unsigned char*)&ret.sin4.sin_addr.s_addr,
+                 (const unsigned char*)  &ca.sin4.sin_addr.s_addr,
+                 (const unsigned char*)key.c_str());
+  return ret;
+}
+
+
+static ComboAddress encryptCA6(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  AES_KEY wctx;
+  AES_set_encrypt_key((const unsigned char*)key.c_str(), 128, &wctx);
+  AES_encrypt((const unsigned char*)&ca.sin6.sin6_addr.s6_addr,
+              (unsigned char*)&ret.sin6.sin6_addr.s6_addr, &wctx);
+
+  return ret;
+}
+
+static ComboAddress decryptCA6(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+  AES_KEY wctx;
+  AES_set_decrypt_key((const unsigned char*)key.c_str(), 128, &wctx);
+  AES_decrypt((const unsigned char*)&ca.sin6.sin6_addr.s6_addr,
+              (unsigned char*)&ret.sin6.sin6_addr.s6_addr, &wctx);
+
+  return ret;
+}
+
+
+ComboAddress encryptCA(const ComboAddress& ca, const std::string& key)
+{
+  if(ca.sin4.sin_family == AF_INET)
+    return encryptCA4(ca, key);
+  else if(ca.sin4.sin_family == AF_INET6)
+    return encryptCA6(ca, key);
+  else
+    throw std::runtime_error("ipcrypt can't encrypt non-IP addresses");
+}
+
+ComboAddress decryptCA(const ComboAddress& ca, const std::string& key)
+{
+  if(ca.sin4.sin_family == AF_INET)
+    return decryptCA4(ca, key);
+  else if(ca.sin4.sin_family == AF_INET6)
+    return decryptCA6(ca, key);
+  else
+    throw std::runtime_error("ipcrypt can't decrypt non-IP addresses");
+
+}
diff --git a/pdns/ipcipher.hh b/pdns/ipcipher.hh
new file mode 100644 (file)
index 0000000..cbb932d
--- /dev/null
@@ -0,0 +1,9 @@
+#pragma once
+#include "iputils.hh"
+#include <string>
+
+// see https://powerdns.org/ipcipher
+
+ComboAddress encryptCA(const ComboAddress& ca, const std::string& key);
+ComboAddress decryptCA(const ComboAddress& ca, const std::string& key);
+std::string makeIPCipherKey(const std::string& password);
index daf5997d6f465acafe15de29f361caa02ee31a97..84425e66cb08514ecbacfdb92e716f266b77d91f 100644 (file)
@@ -47,7 +47,7 @@ int SSocket(int family, int type, int flags)
 
 int SConnect(int sockfd, const ComboAddress& remote)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     RuntimeError(boost::format("connecting socket to %s: %s") % remote.toStringWithPort() % strerror(savederrno));
@@ -57,12 +57,12 @@ int SConnect(int sockfd, const ComboAddress& remote)
 
 int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     if (savederrno == EINPROGRESS) {
       if (timeout <= 0) {
-        return ret;
+        return savederrno;
       }
 
       /* we wait until the connection has been established */
@@ -97,7 +97,7 @@ int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
     }
   }
 
-  return ret;
+  return 0;
 }
 
 int SBind(int sockfd, const ComboAddress& local)
@@ -269,40 +269,108 @@ void ComboAddress::truncate(unsigned int bits) noexcept
   *place &= (~((1<<bitsleft)-1));
 }
 
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf)
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags)
 {
+  int remainingTime = totalTimeout;
+  time_t start = 0;
+  if (totalTimeout) {
+    start = time(nullptr);
+  }
+
   struct msghdr msgh;
   struct iovec iov;
   char cbuf[256];
+
+  /* Set up iov and msgh structures. */
+  memset(&msgh, 0, sizeof(struct msghdr));
+  msgh.msg_control = nullptr;
+  msgh.msg_controllen = 0;
+  if (dest) {
+    msgh.msg_name = reinterpret_cast<void*>(const_cast<ComboAddress*>(dest));
+    msgh.msg_namelen = dest->getSocklen();
+  }
+  else {
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+  }
+
+  msgh.msg_flags = 0;
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  iov.iov_base = reinterpret_cast<void*>(const_cast<char*>(buffer));
+  iov.iov_len = len;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  msgh.msg_flags = 0;
+
+  size_t sent = 0;
   bool firstTry = true;
-  fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast<char*>(buffer), len, &dest);
-  addCMsgSrcAddr(&msgh, cbuf, &local, localItf);
 
   do {
-    ssize_t written = sendmsg(fd, &msgh, 0);
 
-    if (written > 0)
-      return written;
+#ifdef MSG_FASTOPEN
+    if (flags & MSG_FASTOPEN && firstTry == false) {
+      flags &= ~MSG_FASTOPEN;
+    }
+#endif /* MSG_FASTOPEN */
+
+    ssize_t res = sendmsg(fd, &msgh, flags);
 
-    if (errno == EAGAIN) {
-      if (firstTry) {
-        int res = waitForRWData(fd, false, timeout, 0);
-        if (res > 0) {
-          /* there is room available */
-          firstTry = false;
+    if (res > 0) {
+      size_t written = static_cast<size_t>(res);
+      sent += written;
+
+      if (sent == len) {
+        return sent;
+      }
+
+      /* partial write */
+      iov.iov_len -= written;
+      iov.iov_base = reinterpret_cast<void*>(reinterpret_cast<char*>(iov.iov_base) + written);
+      written = 0;
+    }
+    else if (res == -1) {
+      if (errno == EINTR) {
+        continue;
+      }
+      else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS || errno == ENOTCONN) {
+        /* EINPROGRESS might happen with non blocking socket,
+           especially with TCP Fast Open */
+        if (totalTimeout <= 0 && idleTimeout <= 0) {
+          return sent;
+        }
+
+        if (firstTry) {
+          int res = waitForRWData(fd, false, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime, 0);
+          if (res > 0) {
+            /* there is room available */
+            firstTry = false;
+          }
+          else if (res == 0) {
+            throw runtime_error("Timeout while waiting to write data");
+          } else {
+            throw runtime_error("Error while waiting for room to write data");
+          }
         }
-        else if (res == 0) {
+        else {
           throw runtime_error("Timeout while waiting to write data");
-        } else {
-          throw runtime_error("Error while waiting for room to write data");
         }
       }
       else {
-        throw runtime_error("Timeout while waiting to write data");
+        unixDie("failed in sendMsgWithTimeout");
       }
     }
-    else {
-      unixDie("failed in write2WithTimeout");
+    if (totalTimeout) {
+      time_t now = time(nullptr);
+      int elapsed = now - start;
+      if (elapsed >= remainingTime) {
+        throw runtime_error("Timeout while sending data");
+      }
+      start = now;
+      remainingTime -= elapsed;
     }
   }
   while (firstTry);
index 490e45ac436665556dd77f7e545187cb810ca9a2..6454c342cfbf51f28e9ccf0ea6df61668d6185fe 100644 (file)
@@ -278,7 +278,7 @@ union ComboAddress {
     char host[1024];
     int retval = 0;
     if(sin4.sin_family && !(retval = getnameinfo((struct sockaddr*) this, getSocklen(), host, sizeof(host),0, 0, NI_NUMERICHOST)))
-      return host;
+      return string(host);
     else
       return "invalid "+string(gai_strerror(retval));
   }
@@ -1062,7 +1062,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat
 bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv);
 void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, char* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
 ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to);
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf);
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 bool sendSizeAndMsgWithTimeout(int sock, uint16_t bufferLen, const char* buffer, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 /* requires a non-blocking, connected TCP socket */
 bool isTCPSocketUsable(int sock);
index 01d2505d8ffd885481c24b1af15e1cdc37d26e54..485e720bcd8d2700be491a236f88feb1ed52b321 100644 (file)
 
 string doGetStats();
 
-IXFRDistWebServer::IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl) :
+IXFRDistWebServer::IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl, const string &loglevel) :
   d_ws(std::unique_ptr<WebServer>(new WebServer(listenAddress.toString(), listenAddress.getPort())))
 {
   d_ws->setACL(acl);
+  d_ws->setLogLevel(loglevel);
   d_ws->registerWebHandler("/metrics", boost::bind(&IXFRDistWebServer::getMetrics, this, _1, _2));
   d_ws->bind();
 }
index 580e976344332b3d0a577bf89c36ccc8541de766..e1c1734cc39b14e28d546cb8b2f7095a07091381 100644 (file)
@@ -26,7 +26,7 @@
 class IXFRDistWebServer
 {
   public:
-    explicit IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl);
+    explicit IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl, const string &loglevel);
     void go();
 
   private:
index 49a2d93efa64d609bb1b5ba3126926995127c646..d23d554cbe245d1121880319a5cb95d7e709c59c 100644 (file)
@@ -1138,6 +1138,16 @@ static bool parseAndCheckConfig(const string& configpath, YAML::Node& config) {
     }
   }
 
+  if (config["webserver-loglevel"]) {
+    try {
+      config["webserver-loglevel"].as<string>();
+    }
+    catch (const runtime_error &e) {
+      g_log<<Logger::Error<<"Unable to read 'webserver-loglevel' value: "<<e.what()<<endl;
+      retval = false;
+    }
+  }
+
   return retval;
 }
 
@@ -1280,8 +1290,18 @@ int main(int argc, char** argv) {
       }
     }
 
+    string loglevel = "normal";
+    if (config["webserver-loglevel"]) {
+      loglevel = config["webserver-loglevel"].as<string>();
+    }
+
     // Launch the webserver!
-    std::thread(&IXFRDistWebServer::go, IXFRDistWebServer(config["webserver-address"].as<ComboAddress>(), wsACL)).detach();
+    try {
+      std::thread(&IXFRDistWebServer::go, IXFRDistWebServer(config["webserver-address"].as<ComboAddress>(), wsACL, loglevel)).detach();
+    } catch (const PDNSException &e) {
+      g_log<<Logger::Error<<"Unable to start webserver: "<<e.reason<<endl;
+      had_error = true;
+    }
   }
 
   int newuid = 0;
index b995cb5ef1c45aac8dc060c2cf83e9e57719ec15..976b7f23c6224762d428575ca972eefda90468f5 100644 (file)
@@ -88,6 +88,12 @@ webserver-acl:
   - 127.0.0.0/8
   - ::1/128
 
+# How much the webserver should log: 'none', 'normal' or 'detailed'
+# With 'none', nothing is logged except for errors
+# With 'normal' (the default), one line per request is logged in the style of the common log format
+# with 'detailed', the full requests and responses (including headers) are logged
+webserver-loglevel: normal
+
 # The domains to redistribute, the 'master' and 'domains' keys are mandatory.
 # When no port is specified, 53 is used. When specifying ports for IPv6, use the
 # "bracket" notation:
index 44d3f467354a84374b1a7ac4c20a33eb862f93e9..0c193d1dd44af73b8299799e103096e41d68652f 100644 (file)
@@ -47,7 +47,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -80,9 +80,9 @@ KqueueFDMultiplexer::KqueueFDMultiplexer() : d_kevents(new struct kevent[s_maxev
     throw FDMultiplexerException("Setting up kqueue: "+stringerror());
 }
 
-void KqueueFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void KqueueFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct kevent kqevent;
   EV_SET(&kqevent, fd, (&cbmap == &d_readCallbacks) ? EVFILT_READ : EVFILT_WRITE, EV_ADD, 0,0,0);
@@ -144,14 +144,14 @@ int KqueueFDMultiplexer::run(struct timeval* now, int timeout)
   for(int n=0; n < ret; ++n) {
     d_iter=d_readCallbacks.find(d_kevents[n].ident);
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't find ourselves as writable again
     }
 
     d_iter=d_writeCallbacks.find(d_kevents[n].ident);
 
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
 
index 80a727a376859b8c3eb3756ceea00e0089db7461..4ad23d22460b0688afdbd6283887225ec45fc657 100644 (file)
@@ -322,10 +322,10 @@ void RecursorLua4::postPrepareContext()
   d_lw->registerMember("size", &EDNSOptionViewValue::size);
   d_lw->registerFunction<std::string(EDNSOptionViewValue::*)()>("getContent", [](const EDNSOptionViewValue& value) { return std::string(value.content, value.size); });
   d_lw->registerFunction<size_t(EDNSOptionView::*)()>("count", [](const EDNSOptionView& option) { return option.values.size(); });
-  d_lw->registerFunction<std::vector<std::pair<int, string>>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
-      std::vector<std::pair<int, string> > values;
+  d_lw->registerFunction<std::vector<string>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
+      std::vector<string> values;
       for (const auto& value : option.values) {
-        values.push_back(std::make_pair(values.size(), std::string(value.content, value.size)));
+        values.push_back(std::string(value.content, value.size));
       }
       return values;
     });
index 08c42251cf70df8aeb2dd10c4d6cf7ddf96048b4..9b6eea36d56b2668367ed46974fb98a7b58805e5 100644 (file)
@@ -202,7 +202,7 @@ string nowTime()
   // YYYY-mm-dd HH:MM:SS TZOFF
   strftime(buffer, sizeof(buffer), "%F %T %z", tm);
   buffer[sizeof(buffer)-1] = '\0';
-  return buffer;
+  return string(buffer);
 }
 
 uint16_t getShort(const unsigned char *p)
@@ -498,7 +498,7 @@ string getHostname()
   if(gethostname(tmp, MAXHOSTNAMELEN))
     return "UNKNOWN";
 
-  return tmp;
+  return string(tmp);
 }
 
 string itoa(int i)
@@ -571,7 +571,7 @@ string U32ToIP(uint32_t val)
            (val >> 16)&0xff,
            (val >>  8)&0xff,
            (val      )&0xff);
-  return tmp;
+  return string(tmp);
 }
 
 
index d70143d46be8d9d7e1aabc51aa558cb7b6b691e7..927651c332c5324e43b470715e29438d63922052 100644 (file)
 #include <boost/shared_array.hpp>
 #include <boost/tuple/tuple.hpp>
 #include <boost/tuple/tuple_comparison.hpp>
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/hashed_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
 #include <vector>
 #include <map>
 #include <stdexcept>
 #include <string>
 #include <sys/time.h>
 
+using namespace ::boost::multi_index;
+
 class FDMultiplexerException : public std::runtime_error
 {
 public:
@@ -51,14 +57,15 @@ class FDMultiplexer
 {
 public:
   typedef boost::any funcparam_t;
+  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
 protected:
 
-  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
   struct Callback
   {
     callbackfunc_t d_callback;
-    funcparam_t d_parameter;
+    mutable funcparam_t d_parameter;
     struct timeval d_ttd;
+    int d_fd;
   };
 
 public:
@@ -77,15 +84,15 @@ public:
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) = 0;
 
   //! Add an fd to the read watch list - currently an fd can only be on one list at a time!
-  virtual void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t())
+  virtual void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
   {
-    this->addFD(d_readCallbacks, fd, toDo, parameter);
+    this->addFD(d_readCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Add an fd to the write watch list - currently an fd can only be on one list at a time!
-  virtual void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t())
+  virtual void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
   {
-    this->addFD(d_writeCallbacks, fd, toDo, parameter);
+    this->addFD(d_writeCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Remove an fd from the read watch list. You can't call this function on an fd that is closed already!
@@ -104,25 +111,43 @@ public:
 
   virtual void setReadTTD(int fd, struct timeval tv, int timeout)
   {
-    if(!d_readCallbacks.count(fd))
+    const auto& it = d_readCallbacks.find(fd);
+    if (it == d_readCallbacks.end()) {
       throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
+    }
+
+    auto newEntry = *it;
     tv.tv_sec += timeout;
-    d_readCallbacks[fd].d_ttd=tv;
+    newEntry.d_ttd = tv;
+    d_readCallbacks.replace(it, newEntry);
   }
 
-  virtual funcparam_t& getReadParameter(int fd) 
+  virtual void setWriteTTD(int fd, struct timeval tv, int timeout)
   {
-    if(!d_readCallbacks.count(fd))
-      throw FDMultiplexerException("attempt to look up data in multiplexer for unlisted fd "+std::to_string(fd));
-    return d_readCallbacks[fd].d_parameter;
+    const auto& it = d_writeCallbacks.find(fd);
+    if (it == d_writeCallbacks.end()) {
+      throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
+    }
+
+    auto newEntry = *it;
+    tv.tv_sec += timeout;
+    newEntry.d_ttd = tv;
+    d_writeCallbacks.replace(it, newEntry);
   }
 
-  virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv)
+  virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv, bool writes=false)
   {
     std::vector<std::pair<int, funcparam_t> > ret;
-    for(callbackmap_t::iterator i=d_readCallbacks.begin(); i!=d_readCallbacks.end(); ++i)
-      if(i->second.d_ttd.tv_sec && boost::tie(tv.tv_sec, tv.tv_usec) > boost::tie(i->second.d_ttd.tv_sec, i->second.d_ttd.tv_usec)) 
-        ret.push_back(std::make_pair(i->first, i->second.d_parameter));
+    const auto tied = boost::tie(tv.tv_sec, tv.tv_usec);
+    auto& idx = writes ? d_writeCallbacks.get<TTDOrderedTag>() : d_readCallbacks.get<TTDOrderedTag>();
+
+    for (auto it = idx.begin(); it != idx.end(); ++it) {
+      if (it->d_ttd.tv_sec == 0 || tied <= boost::tie(it->d_ttd.tv_sec, it->d_ttd.tv_usec)) {
+        break;
+      }
+      ret.push_back(std::make_pair(it->d_fd, it->d_parameter));
+    }
+
     return ret;
   }
 
@@ -137,31 +162,77 @@ public:
   
   virtual std::string getName() const = 0;
 
+  size_t getWatchedFDCount(bool writeFDs) const
+  {
+    return writeFDs ? d_writeCallbacks.size() : d_readCallbacks.size();
+  }
+
 protected:
-  typedef std::map<int, Callback> callbackmap_t;
+  struct FDBasedTag {};
+  struct TTDOrderedTag {};
+  struct ttd_compare
+  {
+    /* we want a 0 TTD (no timeout) to come _after_ everything else */
+    bool operator() (const struct timeval& lhs, const struct timeval& rhs) const
+    {
+      /* special treatment if at least one of the TTD is 0,
+         normal comparison otherwise */
+      if (lhs.tv_sec == 0 && rhs.tv_sec == 0) {
+        return false;
+      }
+      if (lhs.tv_sec == 0 && rhs.tv_sec != 0) {
+        return false;
+      }
+      if (lhs.tv_sec != 0 && rhs.tv_sec == 0) {
+        return true;
+      }
+
+      return std::tie(lhs.tv_sec, lhs.tv_usec) < std::tie(rhs.tv_sec, rhs.tv_usec);
+    }
+  };
+
+  typedef multi_index_container<
+    Callback,
+    indexed_by <
+                hashed_unique<tag<FDBasedTag>,
+                              member<Callback,int,&Callback::d_fd>
+                              >,
+                ordered_non_unique<tag<TTDOrderedTag>,
+                                   member<Callback,struct timeval,&Callback::d_ttd>,
+                                   ttd_compare
+                                   >
+               >
+  > callbackmap_t;
+
   callbackmap_t d_readCallbacks, d_writeCallbacks;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)=0;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)=0;
   virtual void removeFD(callbackmap_t& cbmap, int fd)=0;
   bool d_inrun;
   callbackmap_t::iterator d_iter;
 
-  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)
   {
     Callback cb;
+    cb.d_fd = fd;
     cb.d_callback=toDo;
     cb.d_parameter=parameter;
     memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
-  
-    if(cbmap.count(fd))
+    if (ttd) {
+      cb.d_ttd = *ttd;
+    }
+
+    auto pair = cbmap.insert(cb);
+    if (!pair.second) {
       throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
-    cbmap[fd]=cb;
+    }
   }
 
   void accountingRemoveFD(callbackmap_t& cbmap, int fd) 
   {
-    if(!cbmap.erase(fd)) 
+    if(!cbmap.erase(fd)) {
       throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
+    }
   }
 };
 
index 16027b61011f488f82a238e08b9635fa74727a20..2d01f12f272d8088eecfdadfdbf484ab72cd943b 100644 (file)
@@ -274,6 +274,7 @@ template<class Key, class Val>void MTasker<Key,Val>::makeThread(tfunc_t *start,
                                             &uc->uc_stack[uc->uc_stack.size()-1]);
 #endif /* PDNS_USE_VALGRIND */
 
+  ++d_threadsCount;
   auto& thread = d_threads[d_maxtid];
   auto mt = this;
   thread.start = [start, val, mt]() {
@@ -316,6 +317,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::schedule(struct timeval*  n
   }
   if(!d_zombiesQueue.empty()) {
     d_threads.erase(d_zombiesQueue.front());
+    --d_threadsCount;
     d_zombiesQueue.pop();
     return true;
   }
@@ -357,7 +359,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::schedule(struct timeval*  n
  */
 template<class Key, class Val>bool MTasker<Key,Val>::noProcesses() const
 {
-  return d_threads.empty();
+  return d_threadsCount == 0;
 }
 
 //! returns the number of processes running
@@ -366,7 +368,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::noProcesses() const
  */
 template<class Key, class Val>unsigned int MTasker<Key,Val>::numProcesses() const
 {
-  return d_threads.size();
+  return d_threadsCount;
 }
 
 //! gives access to the list of Events threads are waiting for
index 87bc6723cff1dc760c513e880dd2b9f46499f027..0365e756a2690c25ba9141ba7960bc866f532872 100644 (file)
@@ -68,9 +68,10 @@ private:
 
   typedef std::map<int, ThreadInfo> mthreads_t;
   mthreads_t d_threads;
+  size_t d_stacksize;
+  size_t d_threadsCount;
   int d_tid;
   int d_maxtid;
-  size_t d_stacksize;
 
   EventVal d_waitval;
   enum waitstatusenum {Error=-1,TimeOut=0,Answer} d_waitstatus;
@@ -110,7 +111,7 @@ public:
       This limit applies solely to the stack, the heap is not limited in any way. If threads need to allocate a lot of data,
       the use of new/delete is suggested. 
    */
-  MTasker(size_t stacksize=16*8192) : d_tid(0), d_maxtid(0), d_stacksize(stacksize), d_waitstatus(Error)
+  MTasker(size_t stacksize=16*8192) : d_stacksize(stacksize), d_threadsCount(0), d_tid(0), d_maxtid(0), d_waitstatus(Error)
   {
     initMainStackBounds();
 
index 72888474af8ceaf1711e7fbd5214d1453730ebfb..e7189ece004c6515c115c929d42838ab004db1eb 100644 (file)
@@ -1003,7 +1003,7 @@ void PacketHandler::makeNOError(DNSPacket* p, DNSPacket* r, const DNSName& targe
   if(d_dk.isSecuredZone(sd.qname))
     addNSECX(p, r, target, wildcard, sd.qname, mode);
 
-  S.ringAccount("noerror-queries",p->qdomain.toLogString()+"/"+p->qtype.getName());
+  S.ringAccount("noerror-queries", p->qdomain, p->qtype);
 }
 
 
@@ -1561,7 +1561,7 @@ DNSPacket *PacketHandler::doQuestion(DNSPacket *p)
     r=p->replyPacket(); // generate an empty reply packet
     r->setRcode(RCode::ServFail);
     S.inc("servfail-packets");
-    S.ringAccount("servfail-queries",p->qdomain.toLogString());
+    S.ringAccount("servfail-queries", p->qdomain, p->qtype);
   }
   catch(PDNSException &e) {
     g_log<<Logger::Error<<"Backend reported permanent error which prevented lookup ("+e.reason+"), aborting"<<endl;
@@ -1573,7 +1573,7 @@ DNSPacket *PacketHandler::doQuestion(DNSPacket *p)
     r=p->replyPacket(); // generate an empty reply packet
     r->setRcode(RCode::ServFail);
     S.inc("servfail-packets");
-    S.ringAccount("servfail-queries",p->qdomain.toLogString());
+    S.ringAccount("servfail-queries", p->qdomain, p->qtype);
   }
   return r; 
 
index ceb5f42d749134a51e2e0e45bb60b114a20f5452..311d861f5416b53de9fc12c5b5057dda3ff0eb90 100644 (file)
@@ -161,6 +161,8 @@ struct RecThreadInfo
   deferredAdd_t deferredAdds;
   struct ThreadPipeSet pipes;
   std::thread thread;
+  MT_t* mt{nullptr};
+  uint64_t numberOfDistributedQueries{0};
   /* handle the web server, carbon, statistics and the control channel */
   bool isHandler{false};
   /* accept incoming queries (and distributes them to the workers if pdns-distributes-queries is set) */
@@ -226,6 +228,7 @@ static std::set<uint16_t> s_avoidUdpSourcePorts;
 #endif
 static uint16_t s_minUdpSourcePort;
 static uint16_t s_maxUdpSourcePort;
+static double s_balancingFactor;
 
 RecursorControlChannel s_rcc; // only active in the handler thread
 RecursorStats g_stats;
@@ -1626,8 +1629,10 @@ static void startDoResolve(void *p)
         else {
           dc->d_tcpConnection->state=TCPConnection::BYTE0;
           Utility::gettimeofday(&g_now, 0); // needs to be updated
-          t_fdm->addReadFD(dc->d_socket, handleRunningTCPQuestion, dc->d_tcpConnection);
-          t_fdm->setReadTTD(dc->d_socket, g_now, g_tcpTimeout);
+          struct timeval ttd = g_now;
+          ttd.tv_sec += g_tcpTimeout;
+
+          t_fdm->addReadFD(dc->d_socket, handleRunningTCPQuestion, dc->d_tcpConnection, &ttd);
         }
       }
     }
@@ -2047,11 +2052,11 @@ static void handleNewTCPQuestion(int fd, FDMultiplexer::funcparam_t& )
     std::shared_ptr<TCPConnection> tc = std::make_shared<TCPConnection>(newsock, addr);
     tc->state=TCPConnection::BYTE0;
 
-    t_fdm->addReadFD(tc->getFD(), handleRunningTCPQuestion, tc);
+    struct timeval ttd;
+    Utility::gettimeofday(&ttd, 0);
+    ttd.tv_sec += g_tcpTimeout;
 
-    struct timeval now;
-    Utility::gettimeofday(&now, 0);
-    t_fdm->setReadTTD(tc->getFD(), now, g_tcpTimeout);
+    t_fdm->addReadFD(tc->getFD(), handleRunningTCPQuestion, tc, &ttd);
   }
 }
 
@@ -2385,6 +2390,7 @@ static void handleNewUDPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             distributeAsyncFunction(data, boost::bind(doProcessUDPQuestion, data, fromaddr, dest, tv, fd));
           }
           else {
+            ++s_threadInfos[t_id].numberOfDistributedQueries;
             doProcessUDPQuestion(data, fromaddr, dest, tv, fd);
           }
         }
@@ -2633,6 +2639,14 @@ static void doStats(void)
     g_log<<Logger::Notice<<"stats: " <<  broadcastAccFunction<uint64_t>(pleaseGetPacketCacheSize) <<
     " packet cache entries, "<<(int)(100.0*broadcastAccFunction<uint64_t>(pleaseGetPacketCacheHits)/SyncRes::s_queries) << "% packet cache hits"<<endl;
 
+    size_t idx = 0;
+    for (const auto& threadInfo : s_threadInfos) {
+      if(threadInfo.isWorker) {
+        g_log<<Logger::Notice<<"Thread "<<idx<<" has been distributed "<<threadInfo.numberOfDistributedQueries<<" queries"<<endl;
+        ++idx;
+      }
+    }
+
     time_t now = time(0);
     if(lastOutputTime && lastQueryCount && now != lastOutputTime) {
       g_log<<Logger::Notice<<"stats: "<< (SyncRes::s_queries - lastQueryCount) / (now - lastOutputTime) <<" qps (average over "<< (now - lastOutputTime) << " seconds)"<<endl;
@@ -2818,7 +2832,7 @@ void broadcastFunction(const pipefunc_t& func)
 
 static bool trySendingQueryToWorker(unsigned int target, ThreadMSG* tmsg)
 {
-  const auto& targetInfo = s_threadInfos[target];
+  auto& targetInfo = s_threadInfos[target];
   if(!targetInfo.isWorker) {
     g_log<<Logger::Error<<"distributeAsyncFunction() tried to assign a query to a non-worker thread"<<endl;
     exit(1);
@@ -2843,9 +2857,52 @@ static bool trySendingQueryToWorker(unsigned int target, ThreadMSG* tmsg)
     }
   }
 
+  ++targetInfo.numberOfDistributedQueries;
+
   return true;
 }
 
+static unsigned int getWorkerLoad(size_t workerIdx)
+{
+  const auto mt = s_threadInfos[/* skip handler */ 1 + g_numDistributorThreads + workerIdx].mt;
+  if (mt != nullptr) {
+    return mt->numProcesses();
+  }
+  return 0;
+}
+
+static unsigned int selectWorker(unsigned int hash)
+{
+  if (s_balancingFactor == 0) {
+    return /* skip handler */ 1 + g_numDistributorThreads + (hash % g_numWorkerThreads);
+  }
+
+  /* we start with one, representing the query we are currently handling */
+  double currentLoad = 1;
+  std::vector<unsigned int> load(g_numWorkerThreads);
+  for (size_t idx = 0; idx < g_numWorkerThreads; idx++) {
+    load[idx] = getWorkerLoad(idx);
+    currentLoad += load[idx];
+    // cerr<<"load for worker "<<idx<<" is "<<load[idx]<<endl;
+  }
+
+  double targetLoad = (currentLoad / g_numWorkerThreads) * s_balancingFactor;
+  // cerr<<"total load is "<<currentLoad<<", number of workers is "<<g_numWorkerThreads<<", target load is "<<targetLoad<<endl;
+
+  unsigned int worker = hash % g_numWorkerThreads;
+  /* at least one server has to be at or below the average load */
+  if (load[worker] > targetLoad) {
+    ++g_stats.rebalancedQueries;
+    do {
+      // cerr<<"worker "<<worker<<" is above the target load, selecting another one"<<endl;
+      worker = (worker + 1) % g_numWorkerThreads;
+    }
+    while(load[worker] > targetLoad);
+  }
+
+  return /* skip handler */ 1 + g_numDistributorThreads + worker;
+}
+
 // This function is only called by the distributor threads, when pdns-distributes-queries is set
 void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
 {
@@ -2855,7 +2912,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
   }
 
   unsigned int hash = hashQuestion(packet.c_str(), packet.length(), g_disthashseed);
-  unsigned int target = /* skip handler */ 1 + g_numDistributorThreads + (hash % g_numWorkerThreads);
+  unsigned int target = selectWorker(hash);
 
   ThreadMSG* tmsg = new ThreadMSG();
   tmsg->func = func;
@@ -2976,30 +3033,31 @@ template string broadcastAccFunction(const boost::function<string*()>& fun); //
 template uint64_t broadcastAccFunction(const boost::function<uint64_t*()>& fun); // explicit instantiation
 template vector<ComboAddress> broadcastAccFunction(const boost::function<vector<ComboAddress> *()>& fun); // explicit instantiation
 template vector<pair<DNSName,uint16_t> > broadcastAccFunction(const boost::function<vector<pair<DNSName, uint16_t> > *()>& fun); // explicit instantiation
+template ThreadTimes broadcastAccFunction(const boost::function<ThreadTimes*()>& fun);
 
 static void handleRCC(int fd, FDMultiplexer::funcparam_t& var)
 {
-  string remote;
-  string msg=s_rcc.recv(&remote);
-  RecursorControlParser rcp;
-  RecursorControlParser::func_t* command;
+  try {
+    string remote;
+    string msg=s_rcc.recv(&remote);
+    RecursorControlParser rcp;
+    RecursorControlParser::func_t* command;
 
-  string answer=rcp.getAnswer(msg, &command);
+    string answer=rcp.getAnswer(msg, &command);
 
-  // If we are inside a chroot, we need to strip
-  if (!arg()["chroot"].empty()) {
-    size_t len = arg()["chroot"].length();
-    remote = remote.substr(len);
-  }
+    // If we are inside a chroot, we need to strip
+    if (!arg()["chroot"].empty()) {
+      size_t len = arg()["chroot"].length();
+      remote = remote.substr(len);
+    }
 
-  try {
     s_rcc.send(answer, &remote);
     command();
   }
-  catch(std::exception& e) {
+  catch(const std::exception& e) {
     g_log<<Logger::Error<<"Error dealing with control socket request: "<<e.what()<<endl;
   }
-  catch(PDNSException& ae) {
+  catch(const PDNSException& ae) {
     g_log<<Logger::Error<<"Error dealing with control socket request: "<<ae.reason<<endl;
   }
 }
@@ -3630,6 +3688,7 @@ static int serviceMain(int argc, char*argv[])
   }
 
   SyncRes::s_minimumTTL = ::arg().asNum("minimum-ttl-override");
+  SyncRes::s_minimumECSTTL = ::arg().asNum("ecs-minimum-ttl-override");
 
   SyncRes::s_nopacketcache = ::arg().mustDo("disable-packetcache");
 
@@ -3654,6 +3713,9 @@ static int serviceMain(int argc, char*argv[])
   SyncRes::s_ecsipv4limit = ::arg().asNum("ecs-ipv4-bits");
   SyncRes::s_ecsipv6limit = ::arg().asNum("ecs-ipv6-bits");
   SyncRes::clearECSStats();
+  SyncRes::s_ecsipv4cachelimit = ::arg().asNum("ecs-ipv4-cache-bits");
+  SyncRes::s_ecsipv6cachelimit = ::arg().asNum("ecs-ipv6-cache-bits");
+  SyncRes::s_ecscachelimitttl = ::arg().asNum("ecs-cache-limit-ttl");
 
   if (!::arg().isEmpty("ecs-scope-zero-address")) {
     ComboAddress scopeZero(::arg()["ecs-scope-zero-address"]);
@@ -3717,6 +3779,12 @@ static int serviceMain(int argc, char*argv[])
 
   g_statisticsInterval = ::arg().asNum("statistics-interval");
 
+  s_balancingFactor = ::arg().asDouble("distribution-load-factor");
+  if (s_balancingFactor != 0.0 && s_balancingFactor < 1.0) {
+    s_balancingFactor = 0.0;
+    g_log<<Logger::Warning<<"Asked to run with a distribution-load-factor below 1.0, disabling it instead"<<endl;
+  }
+
 #ifdef SO_REUSEPORT
   g_reusePort = ::arg().mustDo("reuseport");
 #endif
@@ -4016,6 +4084,7 @@ try
   }
 
   MT=std::unique_ptr<MTasker<PacketID,string> >(new MTasker<PacketID,string>(::arg().asNum("stack-size")));
+  threadInfo.mt = MT.get();
 
 #ifdef HAVE_PROTOBUF
   /* start protobuf export threads if needed */
@@ -4204,6 +4273,7 @@ int main(int argc, char **argv)
     ::arg().set("webserver-port", "Port of webserver to listen on") = "8082";
     ::arg().set("webserver-password", "Password required for accessing the webserver") = "";
     ::arg().set("webserver-allow-from","Webserver access is only allowed from these subnets")="127.0.0.1,::1";
+    ::arg().set("webserver-loglevel", "Amount of logging in the webserver (none, normal, detailed)") = "normal";
     ::arg().set("carbon-ourname", "If set, overrides our reported hostname for carbon stats")="";
     ::arg().set("carbon-server", "If set, send metrics in carbon (graphite) format to this server IP address")="";
     ::arg().set("carbon-interval", "Number of seconds between carbon (graphite) updates")="30";
@@ -4261,7 +4331,11 @@ int main(int argc, char **argv)
     ::arg().set("latency-statistic-size","Number of latency values to calculate the qa-latency average")="10000";
     ::arg().setSwitch( "disable-packetcache", "Disable packetcache" )= "no";
     ::arg().set("ecs-ipv4-bits", "Number of bits of IPv4 address to pass for EDNS Client Subnet")="24";
+    ::arg().set("ecs-ipv4-cache-bits", "Maximum number of bits of IPv4 mask to cache ECS response")="24";
     ::arg().set("ecs-ipv6-bits", "Number of bits of IPv6 address to pass for EDNS Client Subnet")="56";
+    ::arg().set("ecs-ipv6-cache-bits", "Maximum number of bits of IPv6 mask to cache ECS response")="56";
+    ::arg().set("ecs-minimum-ttl-override", "Set under adverse conditions, a minimum TTL for records in ECS-specific answers")="0";
+    ::arg().set("ecs-cache-limit-ttl", "Minimum TTL to cache ECS response")="0";
     ::arg().set("edns-subnet-whitelist", "List of netmasks and domains that we should enable EDNS subnet for")="";
     ::arg().set("ecs-add-for", "List of client netmasks for which EDNS Client Subnet will be added")="0.0.0.0/0, ::/0, " LOCAL_NETS_INVERSE;
     ::arg().set("ecs-scope-zero-address", "Address to send to whitelisted authoritative servers for incoming queries with ECS prefix-length source of 0")="";
@@ -4314,6 +4388,7 @@ int main(int argc, char **argv)
     ::arg().set("udp-source-port-avoid", "List of comma separated UDP port number to avoid")="11211";
     ::arg().set("rng", "Specify random number generator to use. Valid values are auto,sodium,openssl,getrandom,arc4random,urandom.")="auto";
     ::arg().set("public-suffix-list-file", "Path to the Public Suffix List file, if any")="";
+    ::arg().set("distribution-load-factor", "The load factor used when PowerDNS is distributing queries to worker threads")="0.0";
 #ifdef NOD_ENABLED
     ::arg().set("new-domain-tracking", "Track newly observed domains (i.e. never seen before).")="no";
     ::arg().set("new-domain-log", "Log newly observed domains.")="yes";
index cd753330f5d51590d82b825f6bf50e9289bfed4b..9556faf19f5626028f9c8b7a73e865807509c518 100644 (file)
@@ -20,6 +20,7 @@
 #include "zoneparser-tng.hh"
 #include "signingpipe.hh"
 #include "dns_random.hh"
+#include "ipcipher.hh"
 #include <fstream>
 #include <termios.h>            //termios, TCSANOW, ECHO, ICANON
 #include "opensslsigners.hh"
@@ -76,7 +77,7 @@ void loadMainConfig(const std::string& configdir)
     exit(0);
   }
 
-  if(::arg()["config-name"]!="") 
+  if(::arg()["config-name"]!="")
     s_programname+="-"+::arg()["config-name"];
 
   string configname=::arg()["config-dir"]+"/"+s_programname+".conf";
@@ -112,22 +113,22 @@ void loadMainConfig(const std::string& configdir)
   g_log.toConsole(Logger::Error);   // so we print any errors
   BackendMakers().launch(::arg()["launch"]); // vrooooom!
   if(::arg().asNum("loglevel") >= 3) // so you can't kill our errors
-    g_log.toConsole((Logger::Urgency)::arg().asNum("loglevel"));  
+    g_log.toConsole((Logger::Urgency)::arg().asNum("loglevel"));
 
   //cerr<<"Backend: "<<::arg()["launch"]<<", '" << ::arg()["gmysql-dbname"] <<"'" <<endl;
 
   S.declare("qsize-q","Number of questions waiting for database attention");
-          
+
   ::arg().set("max-cache-entries", "Maximum number of cache entries")="1000000";
-  ::arg().set("cache-ttl","Seconds to store packets in the PacketCache")="20";              
+  ::arg().set("cache-ttl","Seconds to store packets in the PacketCache")="20";
   ::arg().set("negquery-cache-ttl","Seconds to store negative query results in the QueryCache")="60";
-  ::arg().set("query-cache-ttl","Seconds to store query results in the QueryCache")="20";              
+  ::arg().set("query-cache-ttl","Seconds to store query results in the QueryCache")="20";
   ::arg().set("default-soa-name","name to insert in the SOA record if none set in the backend")="a.misconfigured.powerdns.server";
   ::arg().set("default-soa-mail","mail address to insert in the SOA record if none set in the backend")="";
   ::arg().set("soa-refresh-default","Default SOA refresh")="10800";
   ::arg().set("soa-retry-default","Default SOA retry")="3600";
   ::arg().set("soa-expire-default","Default SOA expire")="604800";
-  ::arg().set("soa-minimum-ttl","Default SOA minimum ttl")="3600";    
+  ::arg().set("soa-minimum-ttl","Default SOA minimum ttl")="3600";
   ::arg().set("chroot","Switch to this chroot jail")="";
   ::arg().set("dnssec-key-cache-ttl","Seconds to cache DNSSEC keys from the database")="30";
   ::arg().set("domain-metadata-cache-ttl","Seconds to cache domain metadata from the database")="60";
@@ -322,7 +323,7 @@ int checkZone(DNSSECKeeper &dk, UeberBackend &B, const DNSName& zone, const vect
       records.push_back(drr);
     }
   }
-  else 
+  else
     records=*suppliedrecords;
 
   for(auto &rr : records) { // we modify this
@@ -665,7 +666,7 @@ int increaseSerial(const DNSName& zone, DNSSECKeeper &dk)
     NSEC3PARAMRecordContent ns3pr;
     bool narrow;
     bool haveNSEC3=dk.getNSEC3PARAM(zone, &ns3pr, &narrow);
-  
+
     DNSName ordername;
     if(haveNSEC3) {
       if(!narrow)
@@ -716,7 +717,7 @@ void listKey(DomainInfo const &di, DNSSECKeeper& dk, bool printHeader = true) {
     spacelen = (std::to_string(key.first.getKey()->getBits()).length() >= 8) ? 1 : 8 - std::to_string(key.first.getKey()->getBits()).length();
     if (key.first.getKey()->getBits() < 1) {
       cout<<"invalid "<<endl;
-      continue; 
+      continue;
     } else {
       cout<<key.first.getKey()->getBits()<<string(spacelen, ' ');
     }
@@ -778,7 +779,7 @@ int listKeys(const string &zname, DNSSECKeeper& dk){
 int listZone(const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -790,9 +791,9 @@ int listZone(const DNSName &zone) {
   
   while(di.backend->get(rr)) {
     if(rr.qtype.getCode()) {
-      if ( (rr.qtype.getCode() == QType::NS || rr.qtype.getCode() == QType::SRV || rr.qtype.getCode() == QType::MX || rr.qtype.getCode() == QType::CNAME) && !rr.content.empty() && rr.content[rr.content.size()-1] != '.') 
+      if ( (rr.qtype.getCode() == QType::NS || rr.qtype.getCode() == QType::SRV || rr.qtype.getCode() == QType::MX || rr.qtype.getCode() == QType::CNAME) && !rr.content.empty() && rr.content[rr.content.size()-1] != '.')
        rr.content.append(1, '.');
-       
+
       cout<<rr.qname<<"\t"<<rr.ttl<<"\tIN\t"<<rr.qtype.getName()<<"\t"<<rr.content<<"\n";
     }
   }
@@ -801,8 +802,8 @@ int listZone(const DNSName &zone) {
 }
 
 // lovingly copied from http://stackoverflow.com/questions/1798511/how-to-avoid-press-enter-with-any-getchar
-int read1char(){   
-    int c;   
+int read1char(){
+    int c;
     static struct termios oldt, newt;
 
     /*tcgetattr gets the parameters of the current terminal
@@ -814,7 +815,7 @@ int read1char(){
 
     /*ICANON normally takes care that one line at a time will be processed
     that means it will return if it sees a "\n" or an EOF or an EOL*/
-    newt.c_lflag &= ~(ICANON);          
+    newt.c_lflag &= ~(ICANON);
 
     /*Those new settings will be set to STDIN
     TCSANOW tells tcsetattr to change attributes immediately. */
@@ -831,7 +832,7 @@ int read1char(){
 int clearZone(DNSSECKeeper& dk, const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -847,7 +848,7 @@ int clearZone(DNSSECKeeper& dk, const DNSName &zone) {
 int editZone(DNSSECKeeper& dk, const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -937,7 +938,7 @@ int editZone(DNSSECKeeper& dk, const DNSName &zone) {
     cerr<<"\n";
     if(c!='a')
       post.clear();
-    if(c=='e') 
+    if(c=='e')
       goto editMore;
     else if(c=='r')
       goto editAgain;
@@ -1002,6 +1003,20 @@ int editZone(DNSSECKeeper& dk, const DNSName &zone) {
   return EXIT_SUCCESS;
 }
 
+static int xcryptIP(const std::string& cmd, const std::string& ip, const std::string& rkey)
+{
+
+  ComboAddress ca(ip), ret;
+
+  if(cmd=="ipencrypt")
+    ret = encryptCA(ca, rkey);
+  else
+    ret = decryptCA(ca, rkey);
+
+  cout<<ret.toString()<<endl;
+  return EXIT_SUCCESS;
+}
+
 
 int loadZone(DNSName zone, const string& fname) {
   UeberBackend B;
@@ -1013,7 +1028,7 @@ int loadZone(DNSName zone, const string& fname) {
   else {
     cerr<<"Creating '"<<zone<<"'"<<endl;
     B.createDomain(zone);
-    
+
     if(!B.getDomainInfo(zone, di)) {
       cerr<<"Domain '"<<zone<<"' was not created - perhaps backend ("<<::arg()["launch"]<<") does not support storing new zones."<<endl;
       return EXIT_FAILURE;
@@ -1021,13 +1036,13 @@ int loadZone(DNSName zone, const string& fname) {
   }
   DNSBackend* db = di.backend;
   ZoneParserTNG zpt(fname, zone);
-  
+
   DNSResourceRecord rr;
   if(!db->startTransaction(zone, di.id)) {
     cerr<<"Unable to start transaction for load of zone '"<<zone<<"'"<<endl;
     return EXIT_FAILURE;
   }
-  rr.domain_id=di.id;  
+  rr.domain_id=di.id;
   bool haveSOA = false;
   while(zpt.get(rr)) {
     if(!rr.qname.isPartOf(zone) && rr.qname!=zone) {
@@ -1065,7 +1080,7 @@ int createZone(const DNSName &zone, const DNSName& nsname) {
   rr.auth = 1;
   rr.ttl = ::arg().asNum("default-ttl");
   rr.qtype = "SOA";
-  
+
   string soa = (boost::format("%s %s 1")
                 % (nsname.empty() ? ::arg()["default-soa-name"] : nsname.toString())
                 % (::arg().isEmpty("default-soa-mail") ? (DNSName("hostmaster.") + zone).toString() : ::arg()["default-soa-mail"])
@@ -1082,7 +1097,7 @@ int createZone(const DNSName &zone, const DNSName& nsname) {
     rr.content=nsname.toStringNoDot();
     di.backend->feedRecord(rr, DNSName());
   }
-  
+
   di.backend->commitTransaction();
 
   return EXIT_SUCCESS;
@@ -1144,7 +1159,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
   DNSName name;
   if(cmds[2]=="@")
     name=zone;
-  else 
+  else
     name=DNSName(cmds[2])+zone;
 
   rr.qtype = DNSRecordContent::TypeToNumber(cmds[3]);
@@ -1212,7 +1227,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
     newrrs.push_back(rr);
   }
 
-  
+
   di.backend->replaceRRSet(di.id, name, rr.qtype, newrrs);
   // need to be explicit to bypass the ueberbackend cache!
   di.backend->lookup(rr.qtype, name, 0, di.id);
@@ -1224,7 +1239,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
 }
 
 // delete-rrset zone name type
-int deleteRRSet(const std::string& zone_, const std::string& name_, const std::string& type_) 
+int deleteRRSet(const std::string& zone_, const std::string& name_, const std::string& type_)
 {
   UeberBackend B;
   DomainInfo di;
@@ -1237,7 +1252,7 @@ int deleteRRSet(const std::string& zone_, const std::string& name_, const std::s
   DNSName name;
   if(name_=="@")
     name=zone;
-  else 
+  else
     name=DNSName(name_)+zone;
 
   QType qt(QType::chartocode(type_.c_str()));
@@ -1302,16 +1317,16 @@ void testSpeed(DNSSECKeeper& dk, const DNSName& zone, const string& remote, int
   rr.ttl=3600;
   rr.auth=1;
   rr.qclass = QClass::IN;
-  
+
   UeberBackend db("key-only");
-  
+
   if ( ! db.backends.size() )
   {
     throw runtime_error("No backends available for DNSSEC key storage");
   }
 
   ChunkedSigningPipe csp(DNSName(zone), 1, cores);
-  
+
   vector<DNSZoneRecord> signatures;
   uint32_t rnd;
   unsigned char* octets = (unsigned char*)&rnd;
@@ -1320,10 +1335,10 @@ void testSpeed(DNSSECKeeper& dk, const DNSName& zone, const string& remote, int
   dt.set();
   for(unsigned int n=0; n < 100000; ++n) {
     rnd = dns_random(UINT32_MAX);
-    snprintf(tmp, sizeof(tmp), "%d.%d.%d.%d", 
+    snprintf(tmp, sizeof(tmp), "%d.%d.%d.%d",
       octets[0], octets[1], octets[2], octets[3]);
     rr.content=tmp;
-    
+
     snprintf(tmp, sizeof(tmp), "r-%u", rnd);
     rr.qname=DNSName(tmp)+zone;
     DNSZoneRecord dzr;
@@ -1369,7 +1384,7 @@ void verifyCrypto(const string& zone)
       toSign.push_back(DNSRecordContent::mastermake(rr.qtype.getCode(), 1, rr.content));
     }
   }
-  
+
   string msg = getMessageForRRSET(qname, rrc, toSign);
   cerr<<"Verify: "<<DNSCryptoKeyEngine::makeFromPublicKeyString(drc.d_algorithm, drc.d_key)->verify(msg, rrc.d_signature)<<endl;
   if(dsrc.d_digesttype) {
@@ -1871,7 +1886,7 @@ int addOrSetMeta(const DNSName& zone, const string& kind, const vector<string>&
 
 int main(int argc, char** argv)
 try
-{  
+{
   po::options_description desc("Allowed options");
   desc.add_options()
     ("help,h", "produce help message")
@@ -1949,13 +1964,15 @@ try
 #ifdef HAVE_P11KIT1
     cout<<"hsm assign ZONE ALGORITHM {ksk|zsk} MODULE SLOT PIN LABEL"<<endl<<
           "                                   Assign a hardware signing module to a ZONE"<<endl;
-    cout<<"hsm create-key ZONE KEY-ID [BITS]  Create a key using hardware signing module for ZONE (use assign first)"<<endl; 
+    cout<<"hsm create-key ZONE KEY-ID [BITS]  Create a key using hardware signing module for ZONE (use assign first)"<<endl;
     cout<<"                                   BITS defaults to 2048"<<endl;
 #endif
     cout<<"increase-serial ZONE               Increases the SOA-serial by 1. Uses SOA-EDIT"<<endl;
     cout<<"import-tsig-key NAME ALGORITHM KEY Import TSIG key"<<endl;
     cout<<"import-zone-key ZONE FILE          Import from a file a private key, ZSK or KSK"<<endl;
-    cout<<"       [active|inactive] [ksk|zsk]  Defaults to KSK and active"<<endl;
+    cout<<"       [active|inactive] [ksk|zsk] Defaults to KSK and active"<<endl;
+    cout<<"ipdecrypt IP passphrase/key [key]  Encrypt IP address using passphrase or base64 key"<<endl;
+    cout<<"ipencrypt IP passphrase/key [key]  Encrypt IP address using passphrase or base64 key"<<endl;
     cout<<"load-zone ZONE FILE                Load ZONE from FILE, possibly creating zone or atomically"<<endl;
     cout<<"                                   replacing contents"<<endl;
     cout<<"list-algorithms [with-backend]     List all DNSSEC algorithms supported, optionally also listing the crypto library used"<<endl;
@@ -1977,7 +1994,7 @@ try
     cout<<"set-presigned ZONE                 Use presigned RRSIGs from storage"<<endl;
     cout<<"set-publish-cdnskey ZONE           Enable sending CDNSKEY responses for ZONE"<<endl;
     cout<<"set-publish-cds ZONE [DIGESTALGOS] Enable sending CDS responses for ZONE, using DIGESTALGOS as signature algorithms"<<endl;
-    cout<<"                                   DIGESTALGOS should be a comma separated list of numbers, is is '1,2' by default"<<endl;
+    cout<<"                                   DIGESTALGOS should be a comma separated list of numbers, it is '1,2' by default"<<endl;
     cout<<"add-meta ZONE KIND VALUE           Add zone metadata, this adds to the existing KIND"<<endl;
     cout<<"                   [VALUE ...]"<<endl;
     cout<<"set-meta ZONE KIND [VALUE] [VALUE] Set zone metadata, optionally providing a value. *No* value clears meta"<<endl;
@@ -2004,6 +2021,25 @@ try
     return 1;
   }
 
+  if(cmds[0] == "ipencrypt" || cmds[0]=="ipdecrypt") {
+    if(cmds.size() < 3 || (cmds.size()== 4 && cmds[3]!="key")) {
+      cerr<<"Syntax: pdnsutil [ipencrypt|ipdecrypt] IP passphrase [key]"<<endl;
+      return 0;
+    }
+    string key;
+    if(cmds.size()==4) {
+      if(B64Decode(cmds[2], key) < 0) {
+        cerr<<"Could not parse '"<<cmds[3]<<"' as base64"<<endl;
+        return 0;
+      }
+    }
+    else {
+      key = makeIPCipherKey(cmds[2]);
+    }
+    exit(xcryptIP(cmds[0], cmds[1], key));
+  }
+
+
   if(cmds[0] == "test-algorithms") {
     if (testAlgorithms())
       return 0;
@@ -2071,8 +2107,8 @@ try
       return 0;
     }
     unsigned int exitCode = 0;
-    for(unsigned int n = 1; n < cmds.size(); ++n) 
-      if (!rectifyZone(dk, DNSName(cmds[n]))) 
+    for(unsigned int n = 1; n < cmds.size(); ++n)
+      if (!rectifyZone(dk, DNSName(cmds[n])))
        exitCode = 1;
     return exitCode;
   }
@@ -2240,7 +2276,7 @@ try
         active=false;
       } else if(pdns_stou(cmds[n])) {
         bits = pdns_stou(cmds[n]);
-      } else { 
+      } else {
         cerr<<"Unknown algorithm, key flag or size '"<<cmds[n]<<"'"<<endl;
         exit(EXIT_FAILURE);;
       }
@@ -2391,7 +2427,7 @@ try
       }
       dk.commitTransaction();
     }
-    
+
     for(const auto& zone : mustRectify)
       rectifyZone(dk, zone);
 
@@ -2489,7 +2525,7 @@ try
   else if(cmds[0]=="set-presigned") {
     if(cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil set-presigned ZONE"<<endl;
-      return 0; 
+      return 0;
     }
     if (! dk.setPresigned(DNSName(cmds[1]))) {
       cerr << "Could not set presigned for " << cmds[1] << " (is DNSSEC enabled in your backend?)" << endl;
@@ -2527,7 +2563,7 @@ try
   else if(cmds[0]=="unset-presigned") {
     if(cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil unset-presigned ZONE"<<endl;
-      return 0;  
+      return 0;
     }
     if (! dk.unsetPresigned(DNSName(cmds[1]))) {
       cerr << "Could not unset presigned on for " << cmds[1] << endl;
@@ -2573,7 +2609,7 @@ try
     if(narrow) {
       cerr<<"The '"<<zone<<"' zone uses narrow NSEC3, but calculating hash anyhow"<<endl;
     }
-      
+
     cout<<toBase32Hex(hashQNameWithSalt(ns3pr, record))<<endl;
   }
   else if(cmds[0]=="unset-nsec3") {
@@ -2597,7 +2633,7 @@ try
     unsigned int id=pdns_stou(cmds[2]);
     DNSSECPrivateKey dpk=dk.getKeyById(DNSName(zone), id);
     cout << dpk.getKey()->convertToISC() <<endl;
-  }  
+  }
   else if(cmds[0]=="increase-serial") {
     if (cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil increase-serial ZONE"<<endl;
@@ -2626,14 +2662,14 @@ try
     DNSKEYRecordContent drc;
     shared_ptr<DNSCryptoKeyEngine> key(DNSCryptoKeyEngine::makeFromPEMString(drc, raw));
     dpk.setKey(key);
-    
+
     dpk.d_algorithm = pdns_stou(cmds[3]);
-    
+
     if(dpk.d_algorithm == DNSSECKeeper::RSASHA1NSEC3SHA1)
       dpk.d_algorithm = DNSSECKeeper::RSASHA1;
-      
+
     cerr<<(int)dpk.d_algorithm<<endl;
-    
+
     if(cmds.size() > 4) {
       if(pdns_iequals(cmds[4], "ZSK"))
         dpk.d_flags = 256;
@@ -2659,7 +2695,7 @@ try
     } else {
       cout<<std::to_string(id)<<endl;
     }
-    
+
   }
   else if(cmds[0]=="import-zone-key") {
     if(cmds.size() < 3) {
@@ -2673,11 +2709,11 @@ try
     shared_ptr<DNSCryptoKeyEngine> key(DNSCryptoKeyEngine::makeFromISCFile(drc, fname.c_str()));
     dpk.setKey(key);
     dpk.d_algorithm = drc.d_algorithm;
-    
+
     if(dpk.d_algorithm == DNSSECKeeper::RSASHA1NSEC3SHA1)
       dpk.d_algorithm = DNSSECKeeper::RSASHA1;
-    
-    dpk.d_flags = 257; 
+
+    dpk.d_flags = 257;
     bool active=true;
 
     for(unsigned int n = 3; n < cmds.size(); ++n) {
@@ -2689,10 +2725,10 @@ try
         active = 1;
       else if(pdns_iequals(cmds[n], "passive") || pdns_iequals(cmds[n], "inactive")) // passive eventually needs to be removed
         active = 0;
-      else { 
+      else {
         cerr<<"Unknown key flag '"<<cmds[n]<<"'"<<endl;
         exit(1);
-      }          
+      }
     }
     int64_t id;
     if (!dk.addKey(DNSName(zone), dpk, id, active)) {
@@ -2770,14 +2806,14 @@ try
         }
       }
     }
-    dpk->create(bits); 
-    dspk.setKey(dpk); 
-    dspk.d_algorithm = algorithm; 
-    dspk.d_flags = keyOrZone ? 257 : 256; 
+    dpk->create(bits);
+    dspk.setKey(dpk);
+    dspk.d_algorithm = algorithm;
+    dspk.d_flags = keyOrZone ? 257 : 256;
 
-    // print key to stdout 
-    cout << "Flags: " << dspk.d_flags << endl << 
-             dspk.getKey()->convertToISC() << endl; 
+    // print key to stdout
+    cout << "Flags: " << dspk.d_flags << endl <<
+             dspk.getKey()->convertToISC() << endl;
   } else if (cmds[0]=="generate-tsig-key") {
     string usage = "Syntax: " + cmds[0] + " name (hmac-md5|hmac-sha1|hmac-sha224|hmac-sha256|hmac-sha384|hmac-sha512)";
     if (cmds.size() < 3) {
@@ -2860,7 +2896,7 @@ try
         return 1;
      }
      UeberBackend B("default");
-     std::vector<std::string> meta; 
+     std::vector<std::string> meta;
      if (!B.getDomainMetadata(zname, metaKey, meta)) {
        cerr << "Failure enabling TSIG key " << name << " for " << zname << endl;
        return 1;
@@ -2942,7 +2978,7 @@ try
       for(const auto& each_meta: meta) {
         cout << each_meta.first << " = " << boost::join(each_meta.second, ", ") << endl;
       }
-    }  
+    }
     return 0;
 
   } else if (cmds[0]=="set-meta" || cmds[0]=="add-meta") {
@@ -3008,11 +3044,11 @@ try
          pub_label = label;
 
       std::ostringstream iscString;
-      iscString << "Private-key-format: v1.2" << std::endl << 
-        "Algorithm: " << algorithm << std::endl << 
+      iscString << "Private-key-format: v1.2" << std::endl <<
+        "Algorithm: " << algorithm << std::endl <<
         "Engine: " << module << std::endl <<
         "Slot: " << slot << std::endl <<
-        "PIN: " << pin << std::endl << 
+        "PIN: " << pin << std::endl <<
         "Label: " << label << std::endl <<
         "PubLabel: " << pub_label << std::endl;
 
@@ -3067,19 +3103,19 @@ try
         cerr << "Unable to create key for unknown zone '" << zone << "'" << std::endl;
         return 1;
       }
+
       id = pdns_stou(cmds[3]);
-      std::vector<DNSBackend::KeyData> keys; 
+      std::vector<DNSBackend::KeyData> keys;
       if (!B.getDomainKeys(zone, keys)) {
         cerr << "No keys found for zone " << zone << std::endl;
         return 1;
-      } 
+      }
 
       std::shared_ptr<DNSCryptoKeyEngine> dke = nullptr;
-      // lookup correct key      
+      // lookup correct key
       for(DNSBackend::KeyData &kd :  keys) {
         if (kd.id == id) {
-          // found our key. 
+          // found our key.
           DNSKEYRecordContent dkrc;
           dke = DNSCryptoKeyEngine::makeFromISCString(dkrc, kd.content);
         }
@@ -3108,7 +3144,7 @@ try
     }
 #else
     cerr<<"PKCS#11 support not enabled"<<endl;
-    return 1; 
+    return 1;
 #endif
   } else if (cmds[0] == "b2b-migrate") {
     if (cmds.size() < 3) {
index 312c4bfece6919adb7db0e84cf0f50e8014d2064..399ad132ae8a3bd6dca4d566d2ed03fa952584e4 100644 (file)
@@ -37,7 +37,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
 
   string getName() const override
@@ -60,20 +60,14 @@ static struct RegisterOurselves
   }
 } doIt;
 
-void PollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void PollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  Callback cb;
-  cb.d_callback=toDo;
-  cb.d_parameter=parameter;
-  memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
-  if(cbmap.count(fd))
-    throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
-  cbmap[fd]=cb;
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 }
 
 void PollFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
 {
-  if(d_inrun && d_iter->first==fd)  // trying to remove us!
+  if(d_inrun && d_iter->d_fd==fd)  // trying to remove us!
     ++d_iter;
 
   if(!cbmap.erase(fd))
@@ -87,13 +81,13 @@ vector<struct pollfd> PollFDMultiplexer::preparePollFD() const
 
   struct pollfd pollfd;
   for(const auto& cb : d_readCallbacks) {
-    pollfd.fd = cb.first;
+    pollfd.fd = cb.d_fd;
     pollfd.events = POLLIN;
     pollfds.push_back(pollfd);
   }
 
   for(const auto& cb : d_writeCallbacks) {
-    pollfd.fd = cb.first;
+    pollfd.fd = cb.d_fd;
     pollfd.events = POLLOUT;
     pollfds.push_back(pollfd);
   }
@@ -110,7 +104,7 @@ void PollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
     throw FDMultiplexerException("poll returned error: " + stringerror());
 
   for(const auto& pollfd : pollfds) {
-    if (pollfd.revents == POLLIN || pollfd.revents == POLLOUT) {
+    if (pollfd.revents & POLLIN || pollfd.revents & POLLOUT) {
       fds.push_back(pollfd.fd);
     }
   }
@@ -134,19 +128,19 @@ int PollFDMultiplexer::run(struct timeval* now, int timeout)
   d_inrun=true;
 
   for(const auto& pollfd : pollfds) {
-    if(pollfd.revents == POLLIN) {
+    if(pollfd.revents & POLLIN) {
       d_iter=d_readCallbacks.find(pollfd.fd);
     
       if(d_iter != d_readCallbacks.end()) {
-        d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
         continue; // so we don't refind ourselves as writable!
       }
     }
-    else if(pollfd.revents == POLLOUT) {
+    else if(pollfd.revents & POLLOUT) {
       d_iter=d_writeCallbacks.find(pollfd.fd);
     
       if(d_iter != d_writeCallbacks.end()) {
-        d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       }
     }
   }
index c6e91e60fc25d71e7a55e11890c36e8a59706120..39939b2f4fbc437f491ed0a7561613a18f00e657 100644 (file)
@@ -25,7 +25,7 @@ public:
 
   virtual int run(struct timeval* tv, int timeout=500);
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter);
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr);
   virtual void removeFD(callbackmap_t& cbmap, int fd);
   string getName()
   {
@@ -59,9 +59,9 @@ PortsFDMultiplexer::PortsFDMultiplexer() : d_pevents(new port_event_t[s_maxevent
     throw FDMultiplexerException("Setting up port: "+stringerror());
 }
 
-void PortsFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void PortsFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   if(port_associate(d_portfd, PORT_SOURCE_FD, fd, (&cbmap == &d_readCallbacks) ? POLLIN : POLLOUT, 0) < 0) {
     cbmap.erase(fd);
@@ -113,7 +113,7 @@ int PortsFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(d_pevents[n].portev_object);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       if(d_readCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
                         POLLIN, 0) < 0)
         throw FDMultiplexerException("Unable to add fd back to ports (read): "+stringerror());
@@ -123,7 +123,7 @@ int PortsFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_writeCallbacks.find(d_pevents[n].portev_object);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       if(d_writeCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
                         POLLOUT, 0) < 0)
         throw FDMultiplexerException("Unable to add fd back to ports (write): "+stringerror());
index 415a36ff30de362efb4bac4ee8fda4da7f5452ae..45b48cea8deb9dc598866ff423ac1de39cf74550 100644 (file)
@@ -114,6 +114,7 @@ static const oid dnssecAuthenticDataQueriesOID[] = { RECURSOR_STATS_OID, 95 };
 static const oid dnssecCheckDisabledQueriesOID[] = { RECURSOR_STATS_OID, 96 };
 static const oid variableResponsesOID[] = { RECURSOR_STATS_OID, 97 };
 static const oid specialMemoryUsageOID[] = { RECURSOR_STATS_OID, 98 };
+static const oid rebalancedQueriesOID[] = { RECURSOR_STATS_OID, 99 };
 
 static std::unordered_map<oid, std::string> s_statsMap;
 
@@ -320,5 +321,6 @@ RecursorSNMPAgent::RecursorSNMPAgent(const std::string& name, const std::string&
   registerCounter64Stat("policy-result-truncate", policyResultTruncateOID, OID_LENGTH(policyResultTruncateOID));
   registerCounter64Stat("policy-result-custom", policyResultCustomOID, OID_LENGTH(policyResultCustomOID));
   registerCounter64Stat("special-memory-usage", specialMemoryUsageOID, OID_LENGTH(specialMemoryUsageOID));
+  registerCounter64Stat("rebalanced-queries", rebalancedQueriesOID, OID_LENGTH(rebalancedQueriesOID));
 #endif /* HAVE_NET_SNMP */
 }
index 4b101b638a3f14e5a74d8ce8866163a2300d5ad9..77bbd872e83c431d7425ef4a1157f8f4cd27d660 100644 (file)
@@ -694,10 +694,29 @@ static string getTAs()
 template<typename T>
 static string setMinimumTTL(T begin, T end)
 {
-  if(end-begin != 1) 
+  if(end-begin != 1)
     return "Need to supply new minimum TTL number\n";
-  SyncRes::s_minimumTTL = pdns_stou(*begin);
-  return "New minimum TTL: " + std::to_string(SyncRes::s_minimumTTL) + "\n";
+  try {
+    SyncRes::s_minimumTTL = pdns_stou(*begin);
+    return "New minimum TTL: " + std::to_string(SyncRes::s_minimumTTL) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new minimum TTL number: " + std::string(e.what()) + "\n";
+  }
+}
+
+template<typename T>
+static string setMinimumECSTTL(T begin, T end)
+{
+  if(end-begin != 1)
+    return "Need to supply new ECS minimum TTL number\n";
+  try {
+    SyncRes::s_minimumECSTTL = pdns_stou(*begin);
+    return "New minimum ECS TTL: " + std::to_string(SyncRes::s_minimumECSTTL) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new ECS minimum TTL number: " + std::string(e.what()) + "\n";
+  }
 }
 
 template<typename T>
@@ -705,8 +724,13 @@ static string setMaxCacheEntries(T begin, T end)
 {
   if(end-begin != 1) 
     return "Need to supply new cache size\n";
-  g_maxCacheEntries = pdns_stou(*begin);
-  return "New max cache entries: " + std::to_string(g_maxCacheEntries) + "\n";
+  try {
+    g_maxCacheEntries = pdns_stou(*begin);
+    return "New max cache entries: " + std::to_string(g_maxCacheEntries) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new cache size: " + std::string(e.what()) + "\n";
+  }
 }
 
 template<typename T>
@@ -714,8 +738,13 @@ static string setMaxPacketCacheEntries(T begin, T end)
 {
   if(end-begin != 1) 
     return "Need to supply new packet cache size\n";
-  g_maxPacketCacheEntries = pdns_stou(*begin);
-  return "New max packetcache entries: " + std::to_string(g_maxPacketCacheEntries) + "\n";
+  try {
+    g_maxPacketCacheEntries = pdns_stou(*begin);
+    return "New max packetcache entries: " + std::to_string(g_maxPacketCacheEntries) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new packet cache size: " + std::string(e.what()) + "\n";
+  }
 }
 
 
@@ -733,6 +762,48 @@ static uint64_t getUserTimeMsec()
   return (ru.ru_utime.tv_sec*1000ULL + ru.ru_utime.tv_usec/1000);
 }
 
+/* This is a pretty weird set of functions. To get per-thread cpu usage numbers,
+   we have to ask a thread over a pipe. We could do so surgically, so if you want to know about
+   thread 3, we pick pipe 3, but we lack that infrastructure.
+
+   We can however ask "execute this function on all threads and add up the results".
+   This is what the first function does using a custom object ThreadTimes, which if you add
+   to each other keeps filling the first one with CPU usage numbers
+*/
+
+static ThreadTimes* pleaseGetThreadCPUMsec()
+{
+  uint64_t ret=0;
+#ifdef RUSAGE_THREAD
+  struct rusage ru;
+  getrusage(RUSAGE_THREAD, &ru);
+  ret = (ru.ru_utime.tv_sec*1000ULL + ru.ru_utime.tv_usec/1000);
+  ret += (ru.ru_stime.tv_sec*1000ULL + ru.ru_stime.tv_usec/1000);
+#endif
+  return new ThreadTimes{ret};
+}
+
+/* Next up, when you want msec data for a specific thread, we check
+   if we recently executed pleaseGetThreadCPUMsec. If we didn't we do so
+   now and consult all threads.
+
+   We then answer you from the (re)fresh(ed) ThreadTimes.
+*/
+static uint64_t doGetThreadCPUMsec(int n)
+{
+  static std::mutex s_mut;
+  static time_t last = 0;
+  static ThreadTimes tt;
+
+  std::lock_guard<std::mutex> l(s_mut);
+  if(last != time(nullptr)) {
+   tt = broadcastAccFunction<ThreadTimes>(pleaseGetThreadCPUMsec);
+   last = time(nullptr);
+  }
+
+  return tt.times.at(n);
+}
+
 static uint64_t calculateUptime()
 {
   return time(nullptr) - g_stats.startupTime;
@@ -1053,6 +1124,9 @@ void registerAllStats()
   addGetStat("user-msec", getUserTimeMsec);
   addGetStat("sys-msec", getSysTimeMsec);
 
+  for(unsigned int n=0; n < g_numThreads; ++n)
+    addGetStat("cpu-msec-thread-"+std::to_string(n), boost::bind(&doGetThreadCPUMsec, n));
+
 #ifdef MALLOC_TRACE
   addGetStat("memory-allocs", boost::bind(&MallocTracer::getAllocs, g_mtracer, string()));
   addGetStat("memory-alloc-flux", boost::bind(&MallocTracer::getAllocFlux, g_mtracer, string()));
@@ -1073,6 +1147,8 @@ void registerAllStats()
   addGetStat("policy-result-truncate", &g_stats.policyResults[DNSFilterEngine::PolicyKind::Truncate]);
   addGetStat("policy-result-custom", &g_stats.policyResults[DNSFilterEngine::PolicyKind::Custom]);
 
+  addGetStat("rebalanced-queries", &g_stats.rebalancedQueries);
+
   /* make sure that the ECS stats are properly initialized */
   SyncRes::clearECSStats();
   for (size_t idx = 0; idx < SyncRes::s_ecsResponsesBySubnetSize4.size(); idx++) {
@@ -1362,6 +1438,7 @@ string RecursorControlParser::getAnswer(const string& question, RecursorControlP
 "reload-lua-script [filename]     (re)load Lua script\n"
 "reload-lua-config [filename]     (re)load Lua configuration file\n"
 "reload-zones                     reload all auth and forward zones\n"
+"set-ecs-minimum-ttl value        set ecs-minimum-ttl-override\n"
 "set-max-cache-entries value      set new maximum cache size\n"
 "set-max-packetcache-entries val  set new maximum packet cache size\n"      
 "set-minimum-ttl value            set minimum-ttl-override\n"
@@ -1532,6 +1609,10 @@ string RecursorControlParser::getAnswer(const string& question, RecursorControlP
     return reloadAuthAndForwards();
   }
 
+  if(cmd=="set-ecs-minimum-ttl") {
+    return setMinimumECSTTL(begin, end);
+  }
+
   if(cmd=="set-max-cache-entries") {
     return setMaxCacheEntries(begin, end);
   }
index 92c8ae5a27ed0b0eb3ae11fd64c662c64771f1e7..2d7052b1b1c15e1c9413189603e5f71335a162e1 100644 (file)
@@ -168,6 +168,7 @@ pdns_recursor_SOURCES = \
        ueberbackend.hh \
        unix_utility.cc \
        utility.hh \
+       uuid-utils.hh uuid-utils.cc \
        validate.cc validate.hh validate-recursor.cc validate-recursor.hh \
        version.cc version.hh \
        webserver.cc webserver.hh \
@@ -235,6 +236,7 @@ testrunner_SOURCES = \
        nsecrecords.cc \
        pdnsexception.hh \
        opensslsigners.cc opensslsigners.hh \
+       pollmplexer.cc \
        protobuf.cc protobuf.hh \
        qtype.cc qtype.hh \
        rcpgenerator.cc \
@@ -263,6 +265,7 @@ testrunner_SOURCES = \
        test-ixfr_cc.cc \
        test-misc_hh.cc \
        test-mtasker.cc \
+       test-mplexer.cc \
        test-negcache_cc.cc \
        test-packetcache_hh.cc \
        test-rcpgenerator_cc.cc \
@@ -335,16 +338,21 @@ endif
 
 if HAVE_FREEBSD
 pdns_recursor_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
 endif
 
 if HAVE_LINUX
 pdns_recursor_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
 endif
 
 if HAVE_SOLARIS
 pdns_recursor_SOURCES += \
        devpollmplexer.cc \
        portsmplexer.cc
+testrunner_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
 endif
 
 if HAVE_PROTOBUF
@@ -362,10 +370,6 @@ testrunner_LDADD += $(PROTOBUF_LIBS)
 testrunner$(OBJEXT): dnsmessage.pb.cc
 
 endif
-
-pdns_recursor_SOURCES += \
-       uuid-utils.hh uuid-utils.cc
-
 endif
 
 rec_control_SOURCES = \
index cc3368cce7dae6fc707066e7c7b9b24e529e843b..49fc8bcf50225e3525c5a4e1ef8ee7db402b0dc3 100644 (file)
@@ -14,8 +14,8 @@ Starting with version 4.0.0, the PowerDNS recursor uses autotools and compiling
 make
 ```
 
-As for dependencies, Boost (http://boost.org/) and OpenSSL (https://openssl.org/)
-are required.
+As for dependencies, Boost (http://boost.org/), OpenSSL (https://openssl.org/),
+and Lua (https://www.lua.org/) are required.
 
 On most modern UNIX distributions, you can simply install 'boost' or
 'boost-dev' or 'boost-devel'. Otherwise, just download boost, and point the
@@ -57,7 +57,7 @@ Homebrew. You need to tell configure where to find OpenSSL, too.
 
 ```sh
 brew install boost lua pkg-config ragel openssl
-./configure --with-modules="" --with-lua PKG_CONFIG_PATH=/usr/local/opt/openssl/lib/pkgconfig
+./configure --with-lua PKG_CONFIG_PATH=/usr/local/opt/openssl/lib/pkgconfig
 make -j4
 ```
 
index 2af4621d08d6daefe3406205d5f5e20e586246ba..6b9fac22e1bd65cfa66c5c1d0431a759301e194a 100644 (file)
@@ -817,6 +817,14 @@ specialMemoryUsage OBJECT-TYPE
         "Memory usage (more precise bbut expensive to retrieve)"
     ::= { stats 98 }
 
+rebalancedQueries OBJECT-TYPE
+    SYNTAX Counter64
+    MAX-ACCESS read-only
+    STATUS current
+    DESCRIPTION
+        "Number of queries re-distributed because the first selected worker thread was above the target load"
+    ::= { stats 99 }
+
 ---
 --- Traps / Notifications
 ---
@@ -957,6 +965,8 @@ recGroup OBJECT-GROUP
         dnssecAuthenticDataQueries,
         dnssecCheckDisabledQueries,
         variableResponses,
+        specialMemoryUsage,
+        rebalancedQueries,
         trapReason
     }
     STATUS current
index 70aeaa2c62db3ff52b0cd5103b9d9735afd049cc..fbad126428bc35290896463ca950f6146df323de 100644 (file)
@@ -82,16 +82,9 @@ AC_DEFUN([PDNS_SELECT_CONTEXT_IMPL], [
 
 PDNS_CHECK_CLOCK_GETTIME
 
-boost_required_version=1.35
-
 PDNS_WITH_PROTOBUF
-AS_IF([test "x$PROTOBUF_LIBS" != "x" -a x"$PROTOC" != "x"],
-  # The protobuf code needs boost::uuid, which is available from 1.42 onward
-  [AC_MSG_WARN([Bumping minimal Boost requirement to 1.42. To keep the requirement at 1.35, disable protobuf support])
-  boost_required_version=1.42]
-)
 
-BOOST_REQUIRE([$boost_required_version])
+BOOST_REQUIRE([1.42])
 
 # Check against flat_set header that requires boost >= 1.48
 BOOST_FIND_HEADER([boost/container/flat_set.hpp], [AC_MSG_NOTICE([boost::container::flat_set not available, will fallback to std::set])])
index a7843705a326ecd2e0b613e5c31dc786e685f008..6a13b76b50abff67f11e851b5625086ce38cb320 100644 (file)
@@ -1,11 +1,42 @@
 Changelogs for 4.1.x
 ====================
 
+.. changelog::
+  :version: 4.1.12
+  :released: 2nd of April 2019
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7495
+    :tickets: 7494
+
+    Correctly interpret an empty AXFR response to an IXFR query.
+
+  .. change::
+    :tags: Improvements, Internals
+    :pullreq: 7647
+
+    Provide CPU usage statistics per thread (worker & distributor).
+
+  .. change::
+    :tags: Improvements, Internals, Performance
+    :pullreq: 7634
+    :tickets: 7507
+
+    Use a bounded load-balancing algo to distribute queries.
+
+  .. change::
+    :tags: Improvements, Internals
+    :pullreq: 7651
+    :tickets: 7631, 7572
+
+    Implement a configurable ECS cache limit so responses with an ECS scope more specific than a certain threshold and a TTL smaller than a specific threshold are not inserted into the records cache at all.
+
 .. changelog::
   :version: 4.1.11
   :released: 1st of February 2019
 
-  Since Spectre/Meltdown, system calls have become more expensive.  This made exporting a very high number of protobuf messages costly, which is addressed in this release by reducing the number of sycalls per message.
+  Since Spectre/Meltdown, system calls have become more expensive.  This made exporting a very high number of protobuf messages costly, which is addressed in this release by reducing the number of syscalls per message.
 
   .. change::
     :tags: Improvements
index d6b8dfe804981c2d7b531979fce9f898830cb63f..8fa4b018cdaba19ce3a1752361ba157cc67c16ed 100644 (file)
@@ -196,6 +196,9 @@ set-dnssec-log-bogus *SETTING*
     DNSSEC validation failures and to 'no' or 'off' to disable logging these
     failures.
 
+set-ecs-minimum-ttl *NUM*
+    Set ecs-minimum-ttl-override to *NUM*.
+
 set-max-cache-entries *NUM*
     Change the maximum number of entries in the DNS cache.  If reduced, the
     cache size will start shrinking to this number as part of the normal
index 3fb0ee71ed17902cf10b953ee1804e1959a07a20..c04ebbc01f0ce91ea9975b364c5feb2b03d3a6b0 100644 (file)
@@ -185,6 +185,10 @@ concurrent-queries
 ^^^^^^^^^^^^^^^^^^
 shows the number of MThreads currently   running
 
+cpu-msec-thread-n
+^^^^^^^^^^^^^^^^^
+shows the number of milliseconds spent in thread n. Available since 4.1.12.
+
 dlg-only-drops
 ^^^^^^^^^^^^^^
 number of records dropped because of :ref:`setting-delegation-only` setting
@@ -399,6 +403,12 @@ questions
 ^^^^^^^^^
 counts all end-user initiated queries with the RD bit   set
 
+rebalanced-queries
+^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.1.12
+
+number of queries balanced to a different worker thread because the first selected one was above the target load configured with 'distribution-load-factor'
+
 resource-limits
 ^^^^^^^^^^^^^^^
 counts number of queries that could not be   performed because of resource limits
index c74a4ca548d40457edc743dc98b6376805fe86a8..78286bb97be37e2d03549f45c0203d12970d41ff 100644 (file)
@@ -278,6 +278,25 @@ Do not log to syslog, only to stdout.
 Use this setting when running inside a supervisor that handles logging (like systemd).
 **Note**: do not use this setting in combination with `daemon`_ as all logging will disappear.
 
+.. _setting-distribution-load-factor:
+
+``distribution-load-factor``
+----------------------------
+.. versionadded:: 4.1.12
+
+-  Double
+-  Default: 0.0
+
+If `pdns-distributes-queries`_ is set and this setting is set to another value
+than 0, the distributor thread will use a bounded load-balancing algorithm while
+distributing queries to worker threads, making sure that no thread is assigned
+more queries than distribution-load-factor times the average number of queries
+currently processed by all the workers.
+For example, with a value of 1.25, no server should get more than 125 % of the
+average load. This helps making sure that all the workers have roughly the same
+share of queries, even if the incoming traffic is very skewed, with a larger
+number of requests asking for the same qname.
+
 .. _setting-distributor-threads:
 
 ``distributor-threads``
@@ -377,6 +396,18 @@ This defaults to not using the requestor address inside RFC1918 and similar "pri
 
 Number of bits of client IPv4 address to pass when sending EDNS Client Subnet address information.
 
+.. _setting-ecs-ipv4-cache-bits:
+
+``ecs-ipv4-cache-bits``
+-----------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 24
+
+Maximum number of bits of client IPv4 address used by the authoritative server (as indicated by the EDNS Client Subnet scope in the answer) for an answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-cache-limit-ttl``.
+That is, only if both the limits apply, the record will not be cached.
+
 .. _setting-ecs-ipv6-bits:
 
 ``ecs-ipv6-bits``
@@ -388,6 +419,41 @@ Number of bits of client IPv4 address to pass when sending EDNS Client Subnet ad
 
 Number of bits of client IPv6 address to pass when sending EDNS Client Subnet address information.
 
+.. _setting-ecs-ipv6-cache-bits:
+
+``ecs-ipv6-cache-bits``
+-----------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 56
+
+Maximum number of bits of client IPv6 address used by the authoritative server (as indicated by the EDNS Client Subnet scope in the answer) for an answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-cache-limit-ttl``.
+That is, only if both the limits apply, the record will not be cached.
+
+.. _setting-ecs-minimum-ttl-override:
+
+``ecs-minimum-ttl-override``
+----------------------------
+-  Integer
+-  Default: 0 (disabled)
+
+This setting artificially raises the TTLs of records in the ANSWER section of ECS-specific answers to be at least this long.
+While this is a gross hack, and violates RFCs, under conditions of DoS, it may enable you to continue serving your customers.
+Can be set at runtime using ``rec_control set-ecs-minimum-ttl 3600``.
+
+.. _setting-ecs-cache-limit-ttl:
+
+``ecs-cache-limit-ttl``
+-----------------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 0 (disabled)
+
+The minimum TTL for an ECS-specific answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-ipv4-cache-bits`` or ``ecs-ipv6-cache-bits``.
+That is, only if both the limits apply, the record will not be cached.
+
 .. _setting-ecs-scope-zero-address:
 
 ``ecs-scope-zero-address``
@@ -1558,6 +1624,47 @@ IP address for the webserver to listen on.
 
 These subnets are allowed to access the webserver.
 
+.. _setting-webserver-loglevel:
+
+``webserver-loglevel``
+----------------------
+.. versionadded:: 4.2.0
+
+-  String, one of "none", "normal", "detailed"
+
+The amount of logging the webserver must do. "none" means no useful webserver information will be logged.
+When set to "normal", the webserver will log a line per request that should be familiar::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+When set to "detailed", all information about the request and response are logged::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Request Details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-encoding: gzip, deflate
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-language: en-US,en;q=0.5
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   connection: keep-alive
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   dnt: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   host: 127.0.0.1:8081
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   upgrade-insecure-requests: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   user-agent: Mozilla/5.0 (X11; Linux x86_64; rv:64.0) Gecko/20100101 Firefox/64.0
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  No body
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Response details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Connection: close
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Length: 49
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Type: text/html; charset=utf-8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Server: PowerDNS/0.0.15896.0.gaba8bab3ab
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Full body: 
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   <!html><title>Not Found</title><h1>Not Found</h1>
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+The value between the hooks is a UUID that is generated for each request. This can be used to find all lines related to a single request.
+
+.. note::
+  The webserver logs these line on the NOTICE level. The :ref:`settings-loglevel` seting must be 5 or higher for these lines to end up in the log.
+
 .. _setting-webserver-password:
 
 ``webserver-password``
index d62a64c194c8cbd195db0c83d905585d048c9f2e..a49068f385fc43f02ce32239cf127d698e7d0f2f 100644 (file)
@@ -126,7 +126,9 @@ static void checkECSOptionValidity(const std::string& sourceStr, uint8_t sourceM
   uint8_t sourceBytes = ((sourceMask - 1) >> 3) + 1;
   BOOST_REQUIRE_EQUAL(ecsOptionStr.size(), (ecsHeaderSize + sourceBytes));
   /* family */
-  BOOST_REQUIRE_EQUAL(ntohs(*(reinterpret_cast<const uint16_t*>(&ecsOptionStr.at(0)))), source.isIPv4() ? 1 : 2);
+  uint16_t u;
+  memcpy(&u, ecsOptionStr.c_str(), sizeof(u));
+  BOOST_REQUIRE_EQUAL(ntohs(u), source.isIPv4() ? 1 : 2);
   /* source mask */
   BOOST_REQUIRE_EQUAL(static_cast<uint8_t>(ecsOptionStr.at(2)), sourceMask);
   BOOST_REQUIRE_EQUAL(static_cast<uint8_t>(ecsOptionStr.at(3)), scopeMask);
diff --git a/pdns/recursordist/test-mplexer.cc b/pdns/recursordist/test-mplexer.cc
new file mode 120000 (symlink)
index 0000000..f406267
--- /dev/null
@@ -0,0 +1 @@
+../test-mplexer.cc
\ No newline at end of file
index f66682bacfde1981c5b2be44ab8fa0737a43683d..4b5e6297225b103aa743da7232ea85819c7bd0f5 100644 (file)
@@ -130,8 +130,12 @@ static void init(bool debug=false)
   SyncRes::s_doIPv6 = true;
   SyncRes::s_ecsipv4limit = 24;
   SyncRes::s_ecsipv6limit = 56;
+  SyncRes::s_ecsipv4cachelimit = 24;
+  SyncRes::s_ecsipv6cachelimit = 56;
+  SyncRes::s_ecscachelimitttl = 0;
   SyncRes::s_rootNXTrust = true;
   SyncRes::s_minimumTTL = 0;
+  SyncRes::s_minimumECSTTL = 0;
   SyncRes::s_serverID = "PowerDNS Unit Tests Server ID";
   SyncRes::clearEDNSLocalSubnets();
   SyncRes::addEDNSLocalSubnet("0.0.0.0/0");
@@ -2069,6 +2073,8 @@ BOOST_AUTO_TEST_CASE(test_skip_negcache_for_variable_response) {
         addRecordToLW(res, "powerdns.com.", QType::NS, "pdns-public-ns1.powerdns.com.", DNSResourceRecord::AUTHORITY, 172800);
         addRecordToLW(res, "pdns-public-ns1.powerdns.com.", QType::A, "192.0.2.1", DNSResourceRecord::ADDITIONAL, 3600);
 
+        srcmask = boost::none;
+
         return 1;
       } else if (ip == ComboAddress("192.0.2.1:53")) {
         if (domain == target) {
@@ -2098,6 +2104,204 @@ BOOST_AUTO_TEST_CASE(test_skip_negcache_for_variable_response) {
   BOOST_CHECK_EQUAL(SyncRes::getNegCacheSize(), 0);
 }
 
+BOOST_AUTO_TEST_CASE(test_ecs_cache_limit_allowed) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("www.powerdns.com.");
+
+  SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::s_ecsipv4cachelimit = 24;
+
+  sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 1);
+
+  /* should have been cached */
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_limit_no_ttl_limit_allowed) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("www.powerdns.com.");
+
+  SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::s_ecsipv4cachelimit = 16;
+
+  sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 1);
+
+  /* should have been cached because /24 is more specific than /16 but TTL limit is nof effective */
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_allowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 30;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have been cached */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_and_scope_allowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 100;
+    SyncRes::s_ecsipv4cachelimit = 24;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have been cached */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_notallowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 100;
+    SyncRes::s_ecsipv4cachelimit = 16;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have NOT been cached because TTL of 60 is too small and /24 is more specific than /16 */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_LT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 0);
+}
+
+
 BOOST_AUTO_TEST_CASE(test_ns_speed) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
@@ -2346,6 +2550,75 @@ BOOST_AUTO_TEST_CASE(test_cache_min_max_ttl) {
   BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_maxcachettl);
 }
 
+BOOST_AUTO_TEST_CASE(test_cache_min_max_ecs_ttl) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("cacheecsttl.powerdns.com.");
+  const ComboAddress ns("192.0.2.1:53");
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::addEDNSDomain(target);
+
+  sr->setAsyncCallback([target,ns](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      if (isRootServer(ip)) {
+
+        setLWResult(res, 0, false, false, true);
+        addRecordToLW(res, domain, QType::NS, "a.gtld-servers.net.", DNSResourceRecord::AUTHORITY, 172800);
+        addRecordToLW(res, "a.gtld-servers.net.", QType::A, ns.toString(), DNSResourceRecord::ADDITIONAL, 20);
+        srcmask = boost::none;
+
+        return 1;
+      } else if (ip == ns) {
+
+        setLWResult(res, 0, true, false, false);
+        addRecordToLW(res, domain, QType::A, "192.0.2.2", DNSResourceRecord::ANSWER, 10);
+
+        return 1;
+      }
+
+      return 0;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  SyncRes::s_minimumTTL = 60;
+  SyncRes::s_minimumECSTTL = 120;
+  SyncRes::s_maxcachettl = 3600;
+
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_REQUIRE_EQUAL(ret.size(), 1);
+  BOOST_CHECK_EQUAL(ret[0].d_ttl, SyncRes::s_minimumECSTTL);
+
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_EQUAL((cached[0].d_ttl - now), SyncRes::s_minimumECSTTL);
+
+  cached.clear();
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::NS), false, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_maxcachettl);
+
+  cached.clear();
+  BOOST_REQUIRE_GT(t_RC->get(now, DNSName("a.gtld-servers.net."), QType(QType::A), false, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_minimumTTL);
+}
+
 BOOST_AUTO_TEST_CASE(test_cache_expired_ttl) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
index 61b4c6b0087f43990b45f14690ec7e67da161365..ac4f75a4c82d38c7be22846046449842a5ed130a 100644 (file)
@@ -24,9 +24,9 @@ void ResponseStats::submitResponse(DNSPacket &p, bool udpOrTCP) {
 
   if(p.d.aa) {
     if (p.d.rcode==RCode::NXDomain)
-      S.ringAccount("nxdomain-queries",p.qdomain.toLogString()+"/"+p.qtype.getName());
-  } else if (p.isEmpty()) {
-    S.ringAccount("unauth-queries",p.qdomain.toLogString()+"/"+p.qtype.getName());
+      S.ringAccount("nxdomain-queries", p.qdomain, p.qtype);
+  } else if (p.d.rcode == RCode::Refused) {
+    S.ringAccount("unauth-queries", p.qdomain, p.qtype);
     S.ringAccount("remotes-unauth",p.d_remote);
   }
 
index 1a504ef30232bb6826d8606fbf7bba7075f36688..3f923cb1bd77391750e5a6736e259778566cb3f4 100644 (file)
@@ -791,6 +791,34 @@ struct NOPTest
 
 };
 
+struct StatRingDNSNameQTypeToStringTest
+{
+  explicit StatRingDNSNameQTypeToStringTest(const DNSName &name, const QType type) : d_name(name), d_type(type) {}
+
+  string getName() const { return "StatRing test with DNSName and QType to string"; }
+
+  void operator()() const {
+    S.ringAccount("testring", d_name.toLogString()+"/"+d_type.getName());
+  };
+
+  DNSName d_name;
+  QType d_type;
+};
+
+struct StatRingDNSNameQTypeTest
+{
+  explicit StatRingDNSNameQTypeTest(const DNSName &name, const QType type) : d_name(name), d_type(type) {}
+
+  string getName() const { return "StatRing test with DNSName and QType"; }
+
+  void operator()() const {
+    S.ringAccount("testringdnsname", d_name, d_type);
+  };
+
+  DNSName d_name;
+  QType d_type;
+};
+
 
 
 int main(int argc, char** argv)
@@ -876,6 +904,16 @@ try
   doRun(DNSNameParseTest());
   doRun(DNSNameRootTest());
 
+#ifndef RECURSOR
+  S.doRings();
+
+  S.declareRing("testring", "Just some ring where we'll account things");
+  doRun(StatRingDNSNameQTypeToStringTest(DNSName("example.com"), QType(1)));
+
+  S.declareDNSNameQTypeRing("testringdnsname", "Just some ring where we'll account things");
+  doRun(StatRingDNSNameQTypeTest(DNSName("example.com"), QType(1)));
+#endif
+
   cerr<<"Total runs: " << g_totalRuns<<endl;
 
 }
index 403611eff9b18b9204f7d50fbcde32eeb4e56bae..922315a5e21bc3aeeb2c1c4bd1c321b0d29030b1 100644 (file)
@@ -47,43 +47,42 @@ typedef int ProtocolType; //!< Supported protocol types
 //! Representation of a Socket and many of the Berkeley functions available
 class Socket : public boost::noncopyable
 {
-  Socket(int fd)
+  Socket(int fd): d_socket(fd)
   {
-    d_socket = fd;
-    d_buflen=4096;
-    d_buffer=new char[d_buflen];
   }
 
 public:
   //! Construct a socket of specified address family and socket type.
   Socket(int af, int st, ProtocolType pt=0)
   {
-    if((d_socket=socket(af,st, pt))<0)
+    if((d_socket=socket(af, st, pt))<0)
       throw NetworkError(strerror(errno));
     setCloseOnExec(d_socket);
+  }
 
-    d_buflen=4096;
-    d_buffer=new char[d_buflen];
+  Socket(Socket&& rhs): d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket)
+  {
+    rhs.d_socket = -1;
   }
 
   ~Socket()
   {
     try {
-      closesocket(d_socket);
+      if (d_socket != -1) {
+        closesocket(d_socket);
+      }
     }
     catch(const PDNSException& e) {
     }
-
-    delete[] d_buffer;
   }
 
   //! If the socket is capable of doing so, this function will wait for a connection
-  Socket *accept()
+  std::unique_ptr<Socket> accept()
   {
     struct sockaddr_in remote;
     socklen_t remlen=sizeof(remote);
     memset(&remote, 0, sizeof(remote));
-    int s=::accept(d_socket,(sockaddr *)&remote, &remlen);
+    int s=::accept(d_socket, reinterpret_cast<sockaddr *>(&remote), &remlen);
     if(s<0) {
       if(errno==EAGAIN)
         return nullptr;
@@ -91,21 +90,21 @@ public:
       throw NetworkError("Accepting a connection: "+string(strerror(errno)));
     }
 
-    return new Socket(s);
+    return std::unique_ptr<Socket>(new Socket(s));
   }
 
   //! Get remote address
   bool getRemote(ComboAddress &remote) {
     socklen_t remotelen=sizeof(remote);
-    return (getpeername(d_socket, (struct sockaddr *)&remote, &remotelen) >= 0);
+    return (getpeername(d_socket, reinterpret_cast<struct sockaddr *>(&remote), &remotelen) >= 0);
   }
 
   //! Check remote address against netmaskgroup ng
-  bool acl(NetmaskGroup &ng)
+  bool acl(const NetmaskGroup &ng)
   {
     ComboAddress remote;
     if (getRemote(remote))
-      return ng.match((ComboAddress *) &remote);
+      return ng.match(remote);
 
     return false;
   }
@@ -115,6 +114,7 @@ public:
   {
     ::setNonBlocking(d_socket);
   }
+
   //! Set the socket to blocking
   void setBlocking()
   {
@@ -125,35 +125,22 @@ public:
   {
     try {
       ::setReuseAddr(d_socket);
-    } catch (PDNSException &e) {
+    } catch (const PDNSException &e) {
       throw NetworkError(e.reason);
     }
   }
 
   //! Bind the socket to a specified endpoint
-  void bind(const ComboAddress &local)
+  void bind(const ComboAddress &local, bool reuseaddr=true)
   {
     int tmp=1;
-    if(setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR,(char*)&tmp,sizeof tmp)<0)
+    if(reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp)<0)
       throw NetworkError(string("Setsockopt failed: ")+strerror(errno));
 
-    if(::bind(d_socket,(struct sockaddr *)&local, local.getSocklen())<0)
+    if(::bind(d_socket, reinterpret_cast<const struct sockaddr *>(&local), local.getSocklen())<0)
       throw NetworkError("While binding: "+string(strerror(errno)));
   }
 
-#if 0
-  //! Bind the socket to a specified endpoint
-  void bind(const ComboAddress &ep)
-  {
-    ComboAddress local;
-    memset(reinterpret_cast<char *>(&local),0,sizeof(local));
-    local.sin_family=d_family;
-    local.sin_addr.s_addr=ep.address.byte;
-    local.sin_port=htons(ep.port);
-    
-    bind(local);
-  }
-#endif
   //! Connect the socket to a specified endpoint
   void connect(const ComboAddress &ep, int timeout=0)
   {
@@ -167,20 +154,22 @@ public:
       \param ep Will be filled with the origin of the datagram */
   void recvFrom(string &dgram, ComboAddress &ep)
   {
-    socklen_t remlen=sizeof(ep);
+    socklen_t remlen = sizeof(ep);
     ssize_t bytes;
-    if((bytes=recvfrom(d_socket, d_buffer, d_buflen, 0, (sockaddr *)&ep , &remlen)) <0)
+    d_buffer.resize(s_buflen);
+    if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&ep) , &remlen)) <0)
       throw NetworkError("After recvfrom: "+string(strerror(errno)));
     
-    dgram.assign(d_buffer,bytes);
+    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
   }
 
   bool recvFromAsync(string &dgram, ComboAddress &ep)
   {
     struct sockaddr_in remote;
-    socklen_t remlen=sizeof(remote);
+    socklen_t remlen = sizeof(remote);
     ssize_t bytes;
-    if((bytes=recvfrom(d_socket, d_buffer, d_buflen, 0, (sockaddr *)&remote, &remlen))<0) {
+    d_buffer.resize(s_buflen);
+    if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&remote), &remlen))<0) {
       if(errno!=EAGAIN) {
         throw NetworkError("After async recvfrom: "+string(strerror(errno)));
       }
@@ -188,7 +177,7 @@ public:
         return false;
       }
     }
-    dgram.assign(d_buffer,bytes);
+    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
     return true;
   }
 
@@ -196,7 +185,7 @@ public:
   //! For datagram sockets, send a datagram to a destination
   void sendTo(const char* msg, size_t len, const ComboAddress &ep)
   {
-    if(sendto(d_socket, msg, len, 0, (sockaddr *)&ep, ep.getSocklen())<0)
+    if(sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr *>(&ep), ep.getSocklen())<0)
       throw NetworkError("After sendto: "+string(strerror(errno)));
   }
 
@@ -233,9 +222,9 @@ public:
         throw NetworkError("Writing to a socket: "+string(strerror(errno)));
       if(!res)
         throw NetworkError("EOF on socket");
-      toWrite-=(size_t)res;
-      ptr+=(size_t)res;
-    }while(toWrite);
+      toWrite -= static_cast<size_t>(res);
+      ptr += static_cast<size_t>(res);
+    } while(toWrite);
 
   }
 
@@ -275,7 +264,7 @@ public:
   void writenWithTimeout(const void *buffer, size_t n, int timeout)
   {
     size_t bytes=n;
-    const char *ptr = (char*)buffer;
+    const char *ptr = reinterpret_cast<const char*>(buffer);
     ssize_t ret;
     while(bytes) {
       ret=::write(d_socket, ptr, bytes);
@@ -295,8 +284,8 @@ public:
         throw NetworkError("Did not fulfill TCP write due to EOF");
       }
 
-      ptr += (size_t) ret;
-      bytes -= (size_t) ret;
+      ptr += static_cast<size_t>(ret);
+      bytes -= static_cast<size_t>(ret);
     }
   }
 
@@ -325,19 +314,20 @@ public:
   //! Reads a block of data from the socket to a string
   void read(string &data)
   {
-    ssize_t res=::recv(d_socket,d_buffer,d_buflen,0);
+    d_buffer.resize(s_buflen);
+    ssize_t res=::recv(d_socket, &d_buffer[0], s_buflen, 0);
     if(res<0) 
       throw NetworkError("Reading from a socket: "+string(strerror(errno)));
-    data.assign(d_buffer,res);
+    data.assign(d_buffer, 0, static_cast<size_t>(res));
   }
 
   //! Reads a block of data from the socket to a block of memory
   size_t read(char *buffer, size_t bytes)
   {
-    ssize_t res=::recv(d_socket,buffer,bytes,0);
+    ssize_t res=::recv(d_socket, buffer, bytes, 0);
     if(res<0) 
       throw NetworkError("Reading from a socket: "+string(strerror(errno)));
-    return (size_t) res;
+    return static_cast<size_t>(res);
   }
 
   ssize_t readWithTimeout(char* buffer, size_t n, int timeout)
@@ -366,9 +356,9 @@ public:
   }
   
 private:
-  char *d_buffer;
+  static const size_t s_buflen{4096};
+  std::string d_buffer;
   int d_socket;
-  size_t d_buflen;
 };
 
 
index a4b0436f01113d51658a646d09e7be949d7d6a46..c4f46a5b1f8fc70a62cb183f2d7fa7097c34c86d 100644 (file)
@@ -218,7 +218,7 @@ vector<pair<T, unsigned int> >StatRing<T,Comp>::get() const
 
 void StatBag::declareRing(const string &name, const string &help, unsigned int size)
 {
-  d_rings[name]=StatRing<string>(size);
+  d_rings[name]=StatRing<string, CIStringCompare>(size);
   d_rings[name].setHelp(help);
 }
 
@@ -228,21 +228,33 @@ void StatBag::declareComboRing(const string &name, const string &help, unsigned
   d_comborings[name].setHelp(help);
 }
 
+void StatBag::declareDNSNameQTypeRing(const string &name, const string &help, unsigned int size)
+{
+  d_dnsnameqtyperings[name] = StatRing<std::tuple<DNSName, QType> >(size);
+  d_dnsnameqtyperings[name].setHelp(help);
+}
+
 
 vector<pair<string, unsigned int> > StatBag::getRing(const string &name)
 {
-  if(d_rings.count(name))
+  if(d_rings.count(name)) {
     return d_rings[name].get();
-  else {
+  }
+  vector<pair<string, unsigned int> > ret;
+
+  if (d_comborings.count(name)) {
     typedef pair<SComboAddress, unsigned int> stor_t;
     vector<stor_t> raw =d_comborings[name].get();
-    vector<pair<string, unsigned int> > ret;
     for(const stor_t& stor :  raw) {
       ret.push_back(make_pair(stor.first.ca.toString(), stor.second));
     }
-    return ret;
+  } else if(d_dnsnameqtyperings.count(name)) {
+    auto raw = d_dnsnameqtyperings[name].get();
+    for (auto const &e : raw) {
+      ret.push_back(make_pair(std::get<0>(e.first).toLogString() + "/" + std::get<1>(e.first).getName(), e.second));
+    }
   }
-    
+  return ret;
 }
 
 template<typename T, typename Comp>
@@ -256,16 +268,20 @@ void StatBag::resetRing(const string &name)
 {
   if(d_rings.count(name))
     d_rings[name].reset();
-  else
+  if(d_comborings.count(name))
     d_comborings[name].reset();
+  if(d_dnsnameqtyperings.count(name))
+    d_dnsnameqtyperings[name].reset();
 }
 
 void StatBag::resizeRing(const string &name, unsigned int newsize)
 {
   if(d_rings.count(name))
     d_rings[name].resize(newsize);
-  else
+  if(d_comborings.count(name))
     d_comborings[name].resize(newsize);
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].resize(newsize);
 }
 
 
@@ -273,33 +289,42 @@ unsigned int StatBag::getRingSize(const string &name)
 {
   if(d_rings.count(name))
     return d_rings[name].getSize();
-  else
+  if(d_comborings.count(name))
     return d_comborings[name].getSize();
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].getSize();
+  return 0;
 }
 
 string StatBag::getRingTitle(const string &name)
 {
   if(d_rings.count(name))
     return d_rings[name].getHelp();
-  else 
+  if(d_comborings.count(name))
     return d_comborings[name].getHelp();
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].getHelp();
+  return "";
 }
 
 vector<string>StatBag::listRings()
 {
   vector<string> ret;
-  for(map<string,StatRing<string> >::const_iterator i=d_rings.begin();i!=d_rings.end();++i)
+  for(auto i=d_rings.begin();i!=d_rings.end();++i)
     ret.push_back(i->first);
-  for(map<string,StatRing<SComboAddress> >::const_iterator i=d_comborings.begin();i!=d_comborings.end();++i)
+  for(auto i=d_comborings.begin();i!=d_comborings.end();++i)
     ret.push_back(i->first);
+  for(const auto &i : d_dnsnameqtyperings)
+    ret.push_back(i.first);
 
   return ret;
 }
 
 bool StatBag::ringExists(const string &name)
 {
-  return d_rings.count(name) || d_comborings.count(name);
+  return d_rings.count(name) || d_comborings.count(name) || d_dnsnameqtyperings.count(name);
 }
 
-template class StatRing<std::string>;
+template class StatRing<std::string, CIStringCompare>;
 template class StatRing<SComboAddress>;
+template class StatRing<std::tuple<DNSName, QType> >;
index b92c6a4371030dddbcde49794dd4b4fd60131808..2c8ca2850dfd4d7ed7caa23cda8ac7fc8b3aa385 100644 (file)
@@ -64,8 +64,9 @@ class StatBag
 {
   map<string, AtomicCounter *> d_stats;
   map<string, string> d_keyDescrips;
-  map<string,StatRing<string> >d_rings;
+  map<string,StatRing<string, CIStringCompare> >d_rings;
   map<string,StatRing<SComboAddress> >d_comborings;
+  map<string,StatRing<std::tuple<DNSName, QType> > >d_dnsnameqtyperings;
   typedef boost::function<uint64_t(const std::string&)> func_t;
   typedef map<string, func_t> funcstats_t;
   funcstats_t d_funcstats;
@@ -79,6 +80,7 @@ public:
 
   void declareRing(const string &name, const string &title, unsigned int size=10000);
   void declareComboRing(const string &name, const string &help, unsigned int size=10000);
+  void declareDNSNameQTypeRing(const string &name, const string &help, unsigned int size=10000);
   vector<pair<string, unsigned int> >getRing(const string &name);
   string getRingTitle(const string &name);
   void ringAccount(const char* name, const string &item)
@@ -98,6 +100,14 @@ public:
       d_comborings[name].account(item);
     }
   }
+  void ringAccount(const char* name, const DNSName &dnsname, const QType &qtype)
+  {
+    if(d_doRings) {
+      if(!d_dnsnameqtyperings.count(name))
+       throw runtime_error("Attempting to account to non-existent dnsname+qtype ring '"+std::string(name)+"'");
+      d_dnsnameqtyperings[name].account(std::make_tuple(dnsname, qtype));
+    }
+  }
 
   void doRings()
   {
index 39c20c0fc986302c559376842c25e90d9b845a86..ac8d78cb3035febf0d8446518c267a55bde5a4d2 100644 (file)
@@ -60,6 +60,10 @@ public:
   Stat print(unsigned int depth=0, Stat newstat=Stat(), bool silent=false) const;
   typedef boost::function<void(const StatNode*, const Stat& selfstat, const Stat& childstat)> visitor_t;
   void visit(visitor_t visitor, Stat& newstat, unsigned int depth=0) const;
+  bool empty() const
+  {
+    return children.empty() && s.remotes.empty();
+  }
   typedef std::map<std::string,StatNode, CIStringCompare> children_t;
   children_t children;
 
index 933d32c6826f8bf04c9ce1a3e542ddff4f247aa9..1b71d3c190b1ae2b27b3b308415aa292421e0dfe 100644 (file)
@@ -55,10 +55,12 @@ unsigned int SyncRes::s_maxqperq;
 unsigned int SyncRes::s_maxtotusec;
 unsigned int SyncRes::s_maxdepth;
 unsigned int SyncRes::s_minimumTTL;
+unsigned int SyncRes::s_minimumECSTTL;
 unsigned int SyncRes::s_packetcachettl;
 unsigned int SyncRes::s_packetcacheservfailttl;
 unsigned int SyncRes::s_serverdownmaxfails;
 unsigned int SyncRes::s_serverdownthrottletime;
+unsigned int SyncRes::s_ecscachelimitttl;
 std::atomic<uint64_t> SyncRes::s_authzonequeries;
 std::atomic<uint64_t> SyncRes::s_queries;
 std::atomic<uint64_t> SyncRes::s_outgoingtimeouts;
@@ -77,6 +79,9 @@ std::map<uint8_t, std::atomic<uint64_t>> SyncRes::s_ecsResponsesBySubnetSize6;
 
 uint8_t SyncRes::s_ecsipv4limit;
 uint8_t SyncRes::s_ecsipv6limit;
+uint8_t SyncRes::s_ecsipv4cachelimit;
+uint8_t SyncRes::s_ecsipv6cachelimit;
+
 bool SyncRes::s_doIPv6;
 bool SyncRes::s_nopacketcache;
 bool SyncRes::s_rootNXTrust;
@@ -2415,7 +2420,31 @@ RCode::rcodes_ SyncRes::updateCacheFromRecords(unsigned int depth, LWResult& lwr
        - NS, A and AAAA (used for infra queries)
     */
     if (i->first.type != QType::NSEC3 && (i->first.type == QType::DS || i->first.type == QType::NS || i->first.type == QType::A || i->first.type == QType::AAAA || isAA || wasForwardRecurse)) {
-      t_RC->replace(d_now.tv_sec, i->first.name, QType(i->first.type), i->second.records, i->second.signatures, authorityRecs, i->first.type == QType::DS ? true : isAA, i->first.place == DNSResourceRecord::ANSWER ? ednsmask : boost::none, recordState);
+
+      bool doCache = true;
+      if (i->first.place == DNSResourceRecord::ANSWER && ednsmask) {
+        // If ednsmask is relevant, we do not want to cache if the scope prefix length is large and TTL is small
+        if (SyncRes::s_ecscachelimitttl > 0) {
+          bool manyMaskBits = (ednsmask->isIpv4() && ednsmask->getBits() > SyncRes::s_ecsipv4cachelimit) ||
+            (ednsmask->isIpv6() && ednsmask->getBits() > SyncRes::s_ecsipv6cachelimit);
+
+          if (manyMaskBits) {
+            uint32_t minttl = UINT32_MAX;
+            for (const auto &it : i->second.records) {
+              if (it.d_ttl < minttl)
+                minttl = it.d_ttl;
+            }
+            bool ttlIsSmall = minttl < SyncRes::s_ecscachelimitttl + d_now.tv_sec;
+            if (ttlIsSmall) {
+              // Case: many bits and ttlIsSmall
+              doCache = false;
+            }
+          }
+        }
+      }
+      if (doCache) {
+        t_RC->replace(d_now.tv_sec, i->first.name, QType(i->first.type), i->second.records, i->second.signatures, authorityRecs, i->first.type == QType::DS ? true : isAA, i->first.place == DNSResourceRecord::ANSWER ? ednsmask : boost::none, recordState);
+      }
     }
 
     if(i->first.place == DNSResourceRecord::ANSWER && ednsmask)
@@ -2817,6 +2846,16 @@ bool SyncRes::processAnswer(unsigned int depth, LWResult& lwr, const DNSName& qn
     }
   }
 
+  /* if the answer is ECS-specific, a minimum TTL is set for this kind of answers
+     and it's higher than the global minimum TTL */
+  if (ednsmask && s_minimumECSTTL > 0 && (s_minimumTTL == 0 || s_minimumECSTTL > s_minimumTTL)) {
+    for(auto& rec : lwr.d_records) {
+      if (rec.d_place == DNSResourceRecord::ANSWER) {
+        rec.d_ttl = max(rec.d_ttl, s_minimumECSTTL);
+      }
+    }
+  }
+
   bool needWildcardProof = false;
   unsigned int wildcardLabelsCount;
   *rcode = updateCacheFromRecords(depth, lwr, qname, qtype, auth, wasForwarded, ednsmask, state, needWildcardProof, wildcardLabelsCount, sendRDQuery);
index f177f47b5b0e586b07d95310d68c2acd34148246..a28ca22c57a34c0a35d5cc362cc07c11ed217551 100644 (file)
@@ -705,6 +705,7 @@ public:
 
   static string s_serverID;
   static unsigned int s_minimumTTL;
+  static unsigned int s_minimumECSTTL;
   static unsigned int s_maxqperq;
   static unsigned int s_maxtotusec;
   static unsigned int s_maxdepth;
@@ -715,8 +716,11 @@ public:
   static unsigned int s_packetcacheservfailttl;
   static unsigned int s_serverdownmaxfails;
   static unsigned int s_serverdownthrottletime;
+  static unsigned int s_ecscachelimitttl;
   static uint8_t s_ecsipv4limit;
   static uint8_t s_ecsipv6limit;
+  static uint8_t s_ecsipv4cachelimit;
+  static uint8_t s_ecsipv6cachelimit;
   static bool s_doIPv6;
   static bool s_noEDNSPing;
   static bool s_noEDNS;
@@ -956,6 +960,7 @@ struct RecursorStats
   std::atomic<uint64_t> dnssecValidations; // should be the sum of all dnssecResult* stats
   std::map<vState, std::atomic<uint64_t> > dnssecResults;
   std::map<DNSFilterEngine::PolicyKind, std::atomic<uint64_t> > policyResults;
+  std::atomic<uint64_t> rebalancedQueries{0};
 };
 
 //! represents a running TCP/IP client session
@@ -1035,3 +1040,14 @@ void doCarbonDump(void*);
 void primeHints(void);
 
 extern __thread struct timeval g_now;
+
+struct ThreadTimes
+{
+  uint64_t msec;
+  vector<uint64_t> times;
+  ThreadTimes& operator+=(const ThreadTimes& rhs)
+  {
+    times.push_back(rhs.msec);
+    return *this;
+  }
+};
index 0d5bfa514eea42333c21f7fea3c86341bc356068..5c2c90a0532b291c1be88488a8ba8c6d2fc650c8 100644 (file)
@@ -4,12 +4,18 @@
 
 #include "misc.hh"
 
+enum class IOState { Done, NeedRead, NeedWrite };
+
 class TLSConnection
 {
 public:
   virtual ~TLSConnection() { }
+  virtual void doHandshake() = 0;
+  virtual IOState tryHandshake() = 0;
   virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
+  virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
+  virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
   virtual void close() = 0;
 
 protected:
@@ -131,19 +137,15 @@ public:
     return res;
   }
 
-  std::set<int> d_cpus;
   std::vector<std::pair<std::string, std::string>> d_certKeyPairs;
   ComboAddress d_addr;
   std::string d_ciphers;
   std::string d_provider;
-  std::string d_interface;
   std::string d_ticketKeyFile;
 
   size_t d_maxStoredSessions{20480};
   time_t d_ticketsKeyRotationDelay{43200};
-  int d_tcpFastOpenQueueSize{0};
   uint8_t d_numberOfTicketsKeys{5};
-  bool d_reusePort{false};
   bool d_enableTickets{true};
 
 private:
@@ -153,12 +155,14 @@ private:
 class TCPIOHandler
 {
 public:
+
   TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
   {
     if (ctx) {
       d_conn = ctx->getConnection(d_socket, timeout, now);
     }
   }
+
   ~TCPIOHandler()
   {
     if (d_conn) {
@@ -168,6 +172,15 @@ public:
       shutdown(d_socket, SHUT_RDWR);
     }
   }
+
+  IOState tryHandshake()
+  {
+    if (d_conn) {
+      return d_conn->tryHandshake();
+    }
+    return IOState::Done;
+  }
+
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
   {
     if (d_conn) {
@@ -176,6 +189,81 @@ public:
       return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
     }
   }
+
+  /* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+     Updates pos everytime a successful read occurs,
+     throws an std::runtime_error in case of IO error,
+     return Done when toRead bytes have been read, needRead or needWrite if the IO operation
+     would block.
+  */
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
+  {
+    if (buffer.size() < (pos + toRead)) {
+      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
+    }
+
+    if (d_conn) {
+      return d_conn->tryRead(buffer, pos, toRead);
+    }
+
+    size_t got = 0;
+    do {
+      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+      if (res == 0) {
+        throw runtime_error("EOF while reading message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedRead;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      got += static_cast<size_t>(res);
+    }
+    while (got < toRead);
+
+    return IOState::Done;
+  }
+
+  /* Tries to write exactly toWrite bytes from the buffer, starting at position pos.
+     Updates pos everytime a successful write occurs,
+     throws an std::runtime_error in case of IO error,
+     return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
+     would block.
+  */
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
+  {
+    if (d_conn) {
+      return d_conn->tryWrite(buffer, pos, toWrite);
+    }
+
+    size_t sent = 0;
+    do {
+      ssize_t res = ::write(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toWrite - sent);
+      if (res == 0) {
+        throw runtime_error("EOF while sending message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedWrite;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      sent += static_cast<size_t>(res);
+    }
+    while (sent < toWrite);
+
+    return IOState::Done;
+  }
+
   size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
   {
     if (d_conn) {
index 706305c87aa718be343b0308dee1ce78b9383ade..9c0a7b80c11256ca94fd0f56833d20e51f7b64b0 100644 (file)
@@ -1352,7 +1352,7 @@ void TCPNameserver::thread()
 
       int sock=-1;
       for(const pollfd& pfd :  d_prfds) {
-        if(pfd.revents == POLLIN) {
+        if(pfd.revents & POLLIN) {
           sock = pfd.fd;
           remote.sin4.sin_family = AF_INET6;
           addrlen=remote.getSocklen();
index 872592eadbf377c0ace3a900acb9cab799bc0eb3..100dd47109fbf539896e8d769d4cd45b1a4610b2 100644 (file)
@@ -1105,8 +1105,7 @@ static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, co
 {
   dnsheader* dh = reinterpret_cast<dnsheader*>(query.data());
 
-  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime);
-  return dq;
+  return DNSQuestion(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime);
 }
 
 static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, vector<uint8_t>&  query, bool resizeBuffer=true)
@@ -1433,8 +1432,7 @@ BOOST_AUTO_TEST_CASE(test_isEDNSOptionInOpt) {
       return false;
     }
 
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
+    if (optLen < optRecordMinimumSize) {
       return false;
     }
 
diff --git a/pdns/test-ipcrypt_cc.cc b/pdns/test-ipcrypt_cc.cc
new file mode 100644 (file)
index 0000000..a3eae8e
--- /dev/null
@@ -0,0 +1,70 @@
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+#include <boost/test/unit_test.hpp>
+#include "ipcipher.hh"
+#include "misc.hh"
+
+using namespace boost;
+
+BOOST_AUTO_TEST_SUITE(test_ipcrypt_hh)
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt4)
+{
+  ComboAddress ca("127.0.0.1");
+  std::string key="0123456789ABCDEF";
+  auto encrypted = encryptCA(ca, key);
+
+  auto decrypted = decryptCA(encrypted, key);
+  BOOST_CHECK_EQUAL(ca.toString(), decrypted.toString());
+}
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt4_vector)
+{
+  vector<pair<string,string>>  tests{   // test vector from https://github.com/veorq/ipcrypt
+    {{"127.0.0.1"},{"114.62.227.59"}},
+    {{"8.8.8.8"},  {"46.48.51.50"}},
+    {{"1.2.3.4"},  {"171.238.15.199"}}};
+
+  std::string key="some 16-byte key";
+
+  for(const auto& p : tests) {
+    auto encrypted = encryptCA(ComboAddress(p.first), key);
+    BOOST_CHECK_EQUAL(encrypted.toString(), p.second);
+    auto decrypted = decryptCA(encrypted, key);
+    BOOST_CHECK_EQUAL(decrypted.toString(), p.first);
+  }
+
+  // test from Frank Denis' test.cc
+  ComboAddress ip("192.168.69.42"), out, dec;
+  string key2;
+  for(int n=0; n<16; ++n)
+    key2.append(1, (char)n+1);
+
+  for (unsigned int i = 0; i < 100000000UL; i++) {
+    out=encryptCA(ip, key2);
+    //    dec=decryptCA(out, key2);
+    // BOOST_CHECK(ip==dec);
+    ip=out;
+  }
+
+  ComboAddress expected("93.155.197.186");
+
+  BOOST_CHECK_EQUAL(ip.toString(), expected.toString());
+}
+
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt6)
+{
+  ComboAddress ca("::1");
+  std::string key="0123456789ABCDEF";
+  auto encrypted = encryptCA(ca, key);
+
+  auto decrypted = decryptCA(encrypted, key);
+  BOOST_CHECK_EQUAL(ca.toString(), decrypted.toString());
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/pdns/test-mplexer.cc b/pdns/test-mplexer.cc
new file mode 100644 (file)
index 0000000..8a7412f
--- /dev/null
@@ -0,0 +1,182 @@
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <thread>
+#include <boost/test/unit_test.hpp>
+
+#include "mplexer.hh"
+#include "misc.hh"
+
+BOOST_AUTO_TEST_SUITE(mplexer)
+
+BOOST_AUTO_TEST_CASE(test_MPlexer) {
+  auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+  BOOST_REQUIRE(mplexer != nullptr);
+
+  struct timeval now;
+  int ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+
+  std::vector<int> readyFDs;
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_CHECK_EQUAL(readyFDs.size(), 0);
+
+  auto timeouts = mplexer->getTimeouts(now);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+
+  int pipes[2];
+  int res = pipe(pipes);
+  BOOST_REQUIRE_EQUAL(res, 0);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(pipes[0]), true);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(pipes[1]), true);
+
+  /* let's declare a TTD that expired 5s ago */
+  struct timeval ttd = now;
+  ttd.tv_sec -= 5;
+
+  bool writeCBCalled = false;
+  auto writeCB = [](int fd, FDMultiplexer::funcparam_t param) {
+                        auto calledPtr = boost::any_cast<bool*>(param);
+                        BOOST_REQUIRE(calledPtr != nullptr);
+                        *calledPtr = true;
+                 };
+  mplexer->addWriteFD(pipes[1],
+                      writeCB,
+                      &writeCBCalled,
+                      &ttd);
+  /* we can't add it twice */
+  BOOST_CHECK_THROW(mplexer->addWriteFD(pipes[1],
+                                        writeCB,
+                                        &writeCBCalled,
+                                        &ttd),
+                    FDMultiplexerException);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), pipes[1]);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  /* no read timeouts */
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+  /* but we should have a write one */
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  /* can't remove from the wrong type of FD */
+  BOOST_CHECK_THROW(mplexer->removeReadFD(pipes[1]), FDMultiplexerException);
+  mplexer->removeWriteFD(pipes[1]);
+  /* can't remove a non-existing FD */
+  BOOST_CHECK_THROW(mplexer->removeWriteFD(pipes[0]), FDMultiplexerException);
+  BOOST_CHECK_THROW(mplexer->removeWriteFD(pipes[1]), FDMultiplexerException);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 0);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+
+  bool readCBCalled = false;
+  auto readCB = [](int fd, FDMultiplexer::funcparam_t param) {
+                        auto calledPtr = boost::any_cast<bool*>(param);
+                        BOOST_REQUIRE(calledPtr != nullptr);
+                        *calledPtr = true;
+                };
+  mplexer->addReadFD(pipes[0],
+                      readCB,
+                      &readCBCalled,
+                      &ttd);
+
+  /* not ready for reading yet */
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 0);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+  BOOST_CHECK_EQUAL(readCBCalled, false);
+
+  /* let's make the pipe readable */
+  BOOST_REQUIRE_EQUAL(write(pipes[1], "0", 1), 1);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), pipes[0]);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(readCBCalled, true);
+
+  /* add back the write FD */
+  mplexer->addWriteFD(pipes[1],
+                      writeCB,
+                      &writeCBCalled,
+                      &ttd);
+
+  /* both should be available */
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 2);
+
+  readCBCalled = false;
+  writeCBCalled = false;
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 2);
+  BOOST_CHECK_EQUAL(readCBCalled, true);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  /* both the read and write FD should be reported */
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[0]);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  struct timeval past = ttd;
+  /* so five seconds before the actual TTD */
+  past.tv_sec -= 5;
+
+  /* no read timeouts */
+  timeouts = mplexer->getTimeouts(past, false);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+  /* and we should not have a write one either */
+  timeouts = mplexer->getTimeouts(past, true);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+
+  /* update the timeouts to now, they should not be reported anymore */
+  mplexer->setReadTTD(pipes[0], now, 0);
+  mplexer->setWriteTTD(pipes[1], now, 0);
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 0);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 0);
+
+  /* put it back into the past */
+  mplexer->setReadTTD(pipes[0], now, -5);
+  mplexer->setWriteTTD(pipes[1], now, -5);
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[0]);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  mplexer->removeReadFD(pipes[0]);
+  mplexer->removeWriteFD(pipes[1]);
+
+  /* clean up */
+  close(pipes[0]);
+  close(pipes[1]);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
index e2e1066e19a8d2cdf51fea962fc8387cab3645d6..bd16702662aed97e411faa048d24cbc721f8963b 100644 (file)
@@ -156,9 +156,11 @@ try
     DNSPacket r(false);
     r.parse((char*)&pak[0], pak.size());
 
-    /* this step is necessary to get a valid hash */
-    DNSPacket cached(false);
-    g_PC->get(&q, &cached);
+    /* this step is necessary to get a valid hash
+       we directly compute the hash instead of querying the
+       cache because 1/ it's faster 2/ no deferred-lookup issues
+    */
+    q.setHash(g_PC->canHashPacket(q.getString()));
 
     const unsigned int maxTTL = 3600;
     g_PC->insert(&q, &r, maxTTL);
@@ -212,6 +214,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
       pthread_join(tid[i], &res);
 
     BOOST_CHECK_EQUAL(PC.size() + S.read("deferred-packetcache-inserts"), 400000);
+    BOOST_CHECK_EQUAL(S.read("deferred-packetcache-lookup"), 0);
     BOOST_CHECK_SMALL(1.0*S.read("deferred-packetcache-inserts"), 10000.0);
 
     for(int i=0; i < 4; ++i)
@@ -224,9 +227,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
     cerr<<"Hits: "<<S.read("packetcache-hit")<<endl;
     cerr<<"Deferred inserts: "<<S.read("deferred-packetcache-inserts")<<endl;
     cerr<<"Deferred lookups: "<<S.read("deferred-packetcache-lookup")<<endl;
+    cerr<<g_PCmissing<<endl;
+    cerr<<PC.size()<<endl;
 */
+
     BOOST_CHECK_EQUAL(g_PCmissing + S.read("packetcache-hit"), 400000);
-    BOOST_CHECK_GT(S.read("deferred-packetcache-inserts") + S.read("deferred-packetcache-lookup"), g_PCmissing);
+    BOOST_CHECK_EQUAL(S.read("deferred-packetcache-inserts") + S.read("deferred-packetcache-lookup"), g_PCmissing);
   }
   catch(PDNSException& e) {
     cerr<<"Had error: "<<e.reason<<endl;
index 979cc70851bf3be2f2c265649b048499536fd728..a168559e178505720390c4f6c9957c6f8bc8ce17 100644 (file)
 #include "dns.hh"
 #include "base64.hh"
 #include "json.hh"
+#include "uuid-utils.hh"
 #include <yahttp/router.hpp>
 
 json11::Json HttpRequest::json()
 {
   string err;
   if(this->body.empty()) {
-    g_log<<Logger::Debug<<"HTTP: JSON document expected in request body, but body was empty" << endl;
+    g_log<<Logger::Debug<<logprefix<<"JSON document expected in request body, but body was empty" << endl;
     throw HttpBadRequestException();
   }
   json11::Json doc = json11::Json::parse(this->body, err);
   if (doc.is_null()) {
-    g_log<<Logger::Debug<<"HTTP: parsing of JSON document failed:" << err << endl;
+    g_log<<Logger::Debug<<logprefix<<"parsing of JSON document failed:" << err << endl;
     throw HttpBadRequestException();
   }
   return doc;
@@ -130,13 +131,13 @@ static void apiWrapper(WebServer::HandlerFunction handler, HttpRequest* req, Htt
   resp->headers["access-control-allow-origin"] = "*";
 
   if (apikey.empty()) {
-    g_log<<Logger::Error<<"HTTP API Request \"" << req->url.path << "\": Authentication failed, API Key missing in config" << endl;
+    g_log<<Logger::Error<<req->logprefix<<"HTTP API Request \"" << req->url.path << "\": Authentication failed, API Key missing in config" << endl;
     throw HttpUnauthorizedException("X-API-Key");
   }
   bool auth_ok = req->compareHeader("x-api-key", apikey) || req->getvars["api-key"] == apikey;
   
   if (!auth_ok) {
-    g_log<<Logger::Error<<"HTTP Request \"" << req->url.path << "\": Authentication by API Key failed" << endl;
+    g_log<<Logger::Error<<req->logprefix<<"HTTP Request \"" << req->url.path << "\": Authentication by API Key failed" << endl;
     throw HttpUnauthorizedException("X-API-Key");
   }
 
@@ -178,7 +179,7 @@ static void webWrapper(WebServer::HandlerFunction handler, HttpRequest* req, Htt
   if (!password.empty()) {
     bool auth_ok = req->compareAuthorization(password);
     if (!auth_ok) {
-      g_log<<Logger::Debug<<"HTTP Request \"" << req->url.path << "\": Web Authentication failed" << endl;
+      g_log<<Logger::Debug<<req->logprefix<<"HTTP Request \"" << req->url.path << "\": Web Authentication failed" << endl;
       throw HttpUnauthorizedException("Basic");
     }
   }
@@ -204,11 +205,11 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
 
   try {
     if (!req.complete) {
-      g_log<<Logger::Debug<<"HTTP: Incomplete request" << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"Incomplete request" << endl;
       throw HttpBadRequestException();
     }
 
-    g_log<<Logger::Debug<<"HTTP: Handling request \"" << req.url.path << "\"" << endl;
+    g_log<<Logger::Debug<<req.logprefix<<"Handling request \"" << req.url.path << "\"" << endl;
 
     YaHTTP::strstr_map_t::iterator header;
 
@@ -223,33 +224,34 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
 
     YaHTTP::THandlerFunction handler;
     if (!YaHTTP::Router::Route(&req, handler)) {
-      g_log<<Logger::Debug<<"HTTP: No route found for \"" << req.url.path << "\"" << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"No route found for \"" << req.url.path << "\"" << endl;
       throw HttpNotFoundException();
     }
 
     try {
       handler(&req, &resp);
-      g_log<<Logger::Debug<<"HTTP: Result for \"" << req.url.path << "\": " << resp.status << ", body length: " << resp.body.size() << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"Result for \"" << req.url.path << "\": " << resp.status << ", body length: " << resp.body.size() << endl;
     }
     catch(HttpException&) {
       throw;
     }
     catch(PDNSException &e) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": Exception: " << e.reason << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": Exception: " << e.reason << endl;
       throw HttpInternalServerErrorException();
     }
     catch(std::exception &e) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": STL Exception: " << e.what() << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": STL Exception: " << e.what() << endl;
       throw HttpInternalServerErrorException();
     }
     catch(...) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": Unknown Exception" << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": Unknown Exception" << endl;
       throw HttpInternalServerErrorException();
     }
   }
   catch(HttpException &e) {
     resp = e.response();
-    g_log<<Logger::Debug<<"HTTP: Error result for \"" << req.url.path << "\": " << resp.status << endl;
+    // TODO rm this logline?
+    g_log<<Logger::Debug<<req.logprefix<<"Error result for \"" << req.url.path << "\": " << resp.status << endl;
     string what = YaHTTP::Utility::status2text(resp.status);
     if(req.accept_html) {
       resp.headers["Content-Type"] = "text/html; charset=utf-8";
@@ -276,49 +278,129 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
   }
 }
 
-void WebServer::serveConnection(std::shared_ptr<Socket> client) const
-try {
-  HttpRequest req;
-  YaHTTP::AsyncRequestLoader yarl;
-  yarl.initialize(&req);
-  int timeout = 5;
-  client->setNonBlocking();
+void WebServer::logRequest(const HttpRequest& req, const ComboAddress& remote) const {
+  if (d_loglevel >= WebServer::LogLevel::Detailed) {
+    auto logprefix = req.logprefix;
+    g_log<<Logger::Notice<<logprefix<<"Request details:"<<endl;
 
-  try {
-    while(!req.complete) {
-      int bytes;
-      char buf[1024];
-      bytes = client->readWithTimeout(buf, sizeof(buf), timeout);
-      if (bytes > 0) {
-        string data = string(buf, bytes);
-        req.complete = yarl.feed(data);
-      } else {
-        // read error OR EOF
-        break;
+    bool first = true;
+    for (const auto& r : req.getvars) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" GET params:"<<endl;
       }
+      g_log<<Logger::Notice<<logprefix<<"  "<<r.first<<": "<<r.second<<endl;
     }
-    yarl.finalize();
-  } catch (YaHTTP::ParseError &e) {
-    // request stays incomplete
-  }
 
-  HttpResponse resp;
-  WebServer::handleRequest(req, resp);
-  ostringstream ss;
-  resp.write(ss);
-  string reply = ss.str();
+    first = true;
+    for (const auto& r : req.postvars) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" POST params:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<r.first<<": "<<r.second<<endl;
+    }
 
-  client->writenWithTimeout(reply.c_str(), reply.size(), timeout);
-}
-catch(PDNSException &e) {
-  g_log<<Logger::Error<<"HTTP Exception: "<<e.reason<<endl;
+    first = true;
+    for (const auto& h : req.headers) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" Headers:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<h.first<<": "<<h.second<<endl;
+    }
+
+    if (req.body.empty()) {
+      g_log<<Logger::Notice<<logprefix<<" No body"<<endl;
+    } else {
+      g_log<<Logger::Notice<<logprefix<<" Full body: "<<endl;
+      g_log<<Logger::Notice<<logprefix<<"  "<<req.body<<endl;
+    }
+  }
 }
-catch(std::exception &e) {
-  if(strstr(e.what(), "timeout")==0)
-    g_log<<Logger::Error<<"HTTP STL Exception: "<<e.what()<<endl;
+
+void WebServer::logResponse(const HttpResponse& resp, const ComboAddress& remote, const string& logprefix) const {
+  if (d_loglevel >= WebServer::LogLevel::Detailed) {
+    g_log<<Logger::Notice<<logprefix<<"Response details:"<<endl;
+    bool first = true;
+    for (const auto& h : resp.headers) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" Headers:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<h.first<<": "<<h.second<<endl;
+    }
+    if (resp.body.empty()) {
+      g_log<<Logger::Notice<<logprefix<<" No body"<<endl;
+    } else {
+      g_log<<Logger::Notice<<logprefix<<" Full body: "<<endl;
+      g_log<<Logger::Notice<<logprefix<<"  "<<resp.body<<endl;
+    }
+  }
 }
-catch(...) {
-  g_log<<Logger::Error<<"HTTP: Unknown exception"<<endl;
+
+void WebServer::serveConnection(std::shared_ptr<Socket> client) const {
+  const string logprefix = d_logprefix + to_string(getUniqueID()) + " ";
+
+  HttpRequest req(logprefix);
+  HttpResponse resp;
+  ComboAddress remote;
+  string reply;
+
+  try {
+    YaHTTP::AsyncRequestLoader yarl;
+    yarl.initialize(&req);
+    int timeout = 5;
+    client->setNonBlocking();
+
+    try {
+      while(!req.complete) {
+        int bytes;
+        char buf[1024];
+        bytes = client->readWithTimeout(buf, sizeof(buf), timeout);
+        if (bytes > 0) {
+          string data = string(buf, bytes);
+          req.complete = yarl.feed(data);
+        } else {
+          // read error OR EOF
+          break;
+        }
+      }
+      yarl.finalize();
+    } catch (YaHTTP::ParseError &e) {
+      // request stays incomplete
+      g_log<<Logger::Warning<<logprefix<<"Unable to parse request: "<<e.what()<<endl;
+    }
+
+    if (d_loglevel >= WebServer::LogLevel::None) {
+      client->getRemote(remote);
+    }
+
+    logRequest(req, remote);
+
+    WebServer::handleRequest(req, resp);
+    ostringstream ss;
+    resp.write(ss);
+    reply = ss.str();
+
+    logResponse(resp, remote, logprefix);
+
+    client->writenWithTimeout(reply.c_str(), reply.size(), timeout);
+  }
+  catch(PDNSException &e) {
+    g_log<<Logger::Error<<logprefix<<"HTTP Exception: "<<e.reason<<endl;
+  }
+  catch(std::exception &e) {
+    if(strstr(e.what(), "timeout")==0)
+      g_log<<Logger::Error<<logprefix<<"HTTP STL Exception: "<<e.what()<<endl;
+  }
+  catch(...) {
+    g_log<<Logger::Error<<logprefix<<"Unknown exception"<<endl;
+  }
+
+  if (d_loglevel >= WebServer::LogLevel::Normal) {
+    g_log<<Logger::Notice<<logprefix<<remote<<" \""<<req.method<<" "<<req.url.path<<" HTTP/"<<req.versionStr(req.version)<<"\" "<<resp.status<<" "<<reply.size()<<endl;
+  }
 }
 
 WebServer::WebServer(const string &listenaddress, int port) :
@@ -332,10 +414,10 @@ void WebServer::bind()
 {
   try {
     d_server = createServer();
-    g_log<<Logger::Warning<<"Listening for HTTP requests on "<<d_server->d_local.toStringWithPort()<<endl;
+    g_log<<Logger::Warning<<d_logprefix<<"Listening for HTTP requests on "<<d_server->d_local.toStringWithPort()<<endl;
   }
   catch(NetworkError &e) {
-    g_log<<Logger::Error<<"Listening on HTTP socket failed: "<<e.what()<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"Listening on HTTP socket failed: "<<e.what()<<endl;
     d_server = nullptr;
   }
 }
@@ -357,28 +439,28 @@ void WebServer::go()
         } else {
           ComboAddress remote;
           if (client->getRemote(remote))
-            g_log<<Logger::Error<<"Webserver closing socket: remote ("<< remote.toString() <<") does not match the set ACL("<<d_acl.toString()<<")"<<endl;
+            g_log<<Logger::Error<<d_logprefix<<"Webserver closing socket: remote ("<< remote.toString() <<") does not match the set ACL("<<d_acl.toString()<<")"<<endl;
         }
       }
       catch(PDNSException &e) {
-        g_log<<Logger::Error<<"PDNSException while accepting a connection in main webserver thread: "<<e.reason<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"PDNSException while accepting a connection in main webserver thread: "<<e.reason<<endl;
       }
       catch(std::exception &e) {
-        g_log<<Logger::Error<<"STL Exception while accepting a connection in main webserver thread: "<<e.what()<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"STL Exception while accepting a connection in main webserver thread: "<<e.what()<<endl;
       }
       catch(...) {
-        g_log<<Logger::Error<<"Unknown exception while accepting a connection in main webserver thread"<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"Unknown exception while accepting a connection in main webserver thread"<<endl;
       }
     }
   }
   catch(PDNSException &e) {
-    g_log<<Logger::Error<<"PDNSException in main webserver thread: "<<e.reason<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"PDNSException in main webserver thread: "<<e.reason<<endl;
   }
   catch(std::exception &e) {
-    g_log<<Logger::Error<<"STL Exception in main webserver thread: "<<e.what()<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"STL Exception in main webserver thread: "<<e.what()<<endl;
   }
   catch(...) {
-    g_log<<Logger::Error<<"Unknown exception in main webserver thread"<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"Unknown exception in main webserver thread"<<endl;
   }
   _exit(1);
 }
index 5fa351a667e7acfa44064d4662a4e84a4bc16302..500f157044fb358e0607b5b067ae3c6343a8d214 100644 (file)
@@ -34,11 +34,12 @@ class WebServer;
 
 class HttpRequest : public YaHTTP::Request {
 public:
-  HttpRequest() : YaHTTP::Request(), accept_json(false), accept_html(false), complete(false) { };
+  HttpRequest(const string& logprefix="") : YaHTTP::Request(), accept_json(false), accept_html(false), complete(false), logprefix(logprefix) { };
 
   bool accept_json;
   bool accept_html;
   bool complete;
+  string logprefix;
   json11::Json json();
 
   // checks password _only_.
@@ -184,8 +185,43 @@ public:
   void registerApiHandler(const string& url, HandlerFunction handler);
   void registerWebHandler(const string& url, HandlerFunction handler);
 
+  enum class LogLevel : uint8_t {
+    None = 0,                // No logs from requests at all
+    Normal = 10,             // A "common log format"-like line e.g. '127.0.0.1 "GET /apache_pb.gif HTTP/1.0" 200 2326'
+    Detailed = 20,           // The full request headers and body, and the full response headers and body
+  };
+
+  void setLogLevel(const string& level) {
+    if (level == "none") {
+      d_loglevel = LogLevel::None;
+      return;
+    }
+
+    if (level == "normal") {
+      d_loglevel = LogLevel::Normal;
+      return;
+    }
+
+    if (level == "detailed") {
+      d_loglevel = LogLevel::Detailed;
+      return;
+    }
+
+    throw PDNSException("Unknown webserver log level: " + level);
+  }
+
+  void setLogLevel(const LogLevel level) {
+    d_loglevel = level;
+  };
+
+  LogLevel getLogLevel() {
+    return d_loglevel;
+  };
+
 protected:
   void registerBareHandler(const string& url, HandlerFunction handler);
+  void logRequest(const HttpRequest& req, const ComboAddress& remote) const;
+  void logResponse(const HttpResponse& resp, const ComboAddress& remote, const string& logprefix) const;
 
   virtual std::shared_ptr<Server> createServer() {
     return std::make_shared<Server>(d_listenaddress, d_port);
@@ -203,6 +239,11 @@ protected:
   bool d_registerWebHandlerCalled{false};
 
   NetmaskGroup d_acl;
+
+  const string d_logprefix = "[webserver] ";
+
+  // Describes the amount of logging the webserver does
+  WebServer::LogLevel d_loglevel{WebServer::LogLevel::Detailed};
 };
 
 #endif /* WEBSERVER_HH */
index d2c3f9ae0db3a53decb4afdb72d649b9bcc34964..9dfa4babaced99f9926c521d73d2b4197c906609 100644 (file)
@@ -72,6 +72,7 @@ AuthWebServer::AuthWebServer() :
     d_ws = new WebServer(arg()["webserver-address"], arg().asNum("webserver-port"));
     d_ws->setApiKey(arg()["api-key"]);
     d_ws->setPassword(arg()["webserver-password"]);
+    d_ws->setLogLevel(arg()["webserver-loglevel"]);
 
     NetmaskGroup acl;
     acl.toMasks(::arg()["webserver-allow-from"]);
@@ -282,9 +283,10 @@ void AuthWebServer::indexfunction(HttpRequest* req, HttpResponse* resp)
 
   ret<<"Total queries: "<<S.read("udp-queries")<<". Question/answer latency: "<<S.read("latency")/1000.0<<"ms</p><br>"<<endl;
   if(req->getvars["ring"].empty()) {
-    vector<string>entries=S.listRings();
-    for(vector<string>::const_iterator i=entries.begin();i!=entries.end();++i)
-      printtable(ret,*i,S.getRingTitle(*i));
+    auto entries = S.listRings();
+    for(const auto &i: entries) {
+      printtable(ret, i, S.getRingTitle(i));
+    }
 
     printvars(ret);
     if(arg().mustDo("webserver-print-arguments"))
@@ -1707,6 +1709,11 @@ static void apiServerZoneDetail(HttpRequest* req, HttpResponse* resp) {
     if(!di.backend->deleteDomain(zonename))
       throw ApiException("Deleting domain '"+zonename.toString()+"' failed: backend delete failed/unsupported");
 
+    // clear caches
+    DNSSECKeeper dk(&B);
+    dk.clearCaches(zonename);
+    purgeAuthCaches(zonename.toString() + "$");
+
     // empty body on success
     resp->body = "";
     resp->status = 204; // No Content: declare that the zone is gone now
index dfb2368c4f55056728a3dba7fda2c7b7bb92a8be..bf4038a2a51a92fbd10aa9489134dd855368e34e 100644 (file)
@@ -41,6 +41,7 @@
 #include "ext/incbin/incbin.h"
 #include "rec-lua-conf.hh"
 #include "rpzloader.hh"
+#include "uuid-utils.hh"
 
 extern thread_local FDMultiplexer* t_fdm;
 
@@ -457,6 +458,7 @@ RecursorWebServer::RecursorWebServer(FDMultiplexer* fdm)
   d_ws = new AsyncWebServer(fdm, arg()["webserver-address"], arg().asNum("webserver-port"));
   d_ws->setApiKey(arg()["api-key"]);
   d_ws->setPassword(arg()["webserver-password"]);
+  d_ws->setLogLevel(arg()["webserver-loglevel"]);
 
   NetmaskGroup acl;
   acl.toMasks(::arg()["webserver-allow-from"]);
@@ -625,50 +627,69 @@ void AsyncServer::newConnection()
 }
 
 // This is an entry point from FDM, so it needs to catch everything.
-void AsyncWebServer::serveConnection(std::shared_ptr<Socket> client) const
-try {
-  HttpRequest req;
-  YaHTTP::AsyncRequestLoader yarl;
-  yarl.initialize(&req);
-  client->setNonBlocking();
-
-  string data;
+void AsyncWebServer::serveConnection(std::shared_ptr<Socket> client) const {
+  const string logprefix = d_logprefix + to_string(getUniqueID()) + " ";
+
+  HttpRequest req(logprefix);
+  HttpResponse resp;
+  ComboAddress remote;
+  string reply;
+
   try {
-    while(!req.complete) {
-      int bytes = arecvtcp(data, 16384, client.get(), true);
-      if (bytes > 0) {
-        req.complete = yarl.feed(data);
-      } else {
-        // read error OR EOF
-        break;
+    YaHTTP::AsyncRequestLoader yarl;
+    yarl.initialize(&req);
+    client->setNonBlocking();
+
+    string data;
+    try {
+      while(!req.complete) {
+        int bytes = arecvtcp(data, 16384, client.get(), true);
+        if (bytes > 0) {
+          req.complete = yarl.feed(data);
+        } else {
+          // read error OR EOF
+          break;
+        }
       }
+      yarl.finalize();
+    } catch (YaHTTP::ParseError &e) {
+      // request stays incomplete
+      g_log<<Logger::Warning<<logprefix<<"Unable to parse request: "<<e.what()<<endl;
+    }
+
+    if (d_loglevel >= WebServer::LogLevel::None) {
+      client->getRemote(remote);
     }
-    yarl.finalize();
-  } catch (YaHTTP::ParseError &e) {
-    // request stays incomplete
+
+    logRequest(req, remote);
+
+    WebServer::handleRequest(req, resp);
+    ostringstream ss;
+    resp.write(ss);
+    reply = ss.str();
+
+    logResponse(resp, remote, logprefix);
+
+    // now send the reply
+    if (asendtcp(reply, client.get()) == -1 || reply.empty()) {
+      g_log<<Logger::Error<<logprefix<<"Failed sending reply to HTTP client"<<endl;
+    }
+  }
+  catch(PDNSException &e) {
+    g_log<<Logger::Error<<logprefix<<"Exception: "<<e.reason<<endl;
+  }
+  catch(std::exception &e) {
+    if(strstr(e.what(), "timeout")==0)
+      g_log<<Logger::Error<<logprefix<<"STL Exception: "<<e.what()<<endl;
+  }
+  catch(...) {
+    g_log<<Logger::Error<<logprefix<<"Unknown exception"<<endl;
   }
 
-  HttpResponse resp;
-  handleRequest(req, resp);
-  ostringstream ss;
-  resp.write(ss);
-  data = ss.str();
-
-  // now send the reply
-  if (asendtcp(data, client.get()) == -1 || data.empty()) {
-    g_log<<Logger::Error<<"Failed sending reply to HTTP client"<<endl;
+  if (d_loglevel >= WebServer::LogLevel::Normal) {
+    g_log<<Logger::Notice<<logprefix<<remote<<" \""<<req.method<<" "<<req.url.path<<" HTTP/"<<req.versionStr(req.version)<<"\" "<<resp.status<<" "<<reply.size()<<endl;
   }
 }
-catch(PDNSException &e) {
-  g_log<<Logger::Error<<"HTTP Exception: "<<e.reason<<endl;
-}
-catch(std::exception &e) {
-  if(strstr(e.what(), "timeout")==0)
-    g_log<<Logger::Error<<"HTTP STL Exception: "<<e.what()<<endl;
-}
-catch(...) {
-  g_log<<Logger::Error<<"HTTP: Unknown exception"<<endl;
-}
 
 void AsyncWebServer::go() {
   if (!d_server)
index e1df6cfa7df44a5631334af0776505a88560a587..70527759b3e4fcf1207f3f288e3e19a5c86cae22 100644 (file)
@@ -214,6 +214,7 @@ class DNSDistTest(unittest.TestCase):
         ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
             sock.bind(("127.0.0.1", port))
@@ -304,6 +305,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTCPConnection(cls, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -313,6 +315,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -385,6 +388,7 @@ class DNSDistTest(unittest.TestCase):
             for response in responses:
                 cls._toResponderQueue.put(response, True, timeout)
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -475,6 +479,7 @@ class DNSDistTest(unittest.TestCase):
         ourNonce = libnacl.utils.rand_nonce()
         theirNonce = None
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
index 6ad4ff3d674185a3e7ff34e263fb3e61ce69561a..7fbe25b04b7470821c56380ecbf9e670e62850c7 100644 (file)
@@ -2,7 +2,7 @@ dnspython>=1.11,<1.16.0
 nose>=1.3.7
 libnacl>=1.4.3
 requests>=2.1.0
-protobuf>=2.5,<3.0; sys_platform != 'darwin'
-protobuf>=3.0; sys_platform == 'darwin'
+protobuf>=2.5,<3.0; sys_platform != 'darwin' and sys_platform != 'openbsd6'
+protobuf>=3.0; sys_platform == 'darwin' or sys_platform == 'openbsd6'
 pysnmp>=4.3.4
 future>=0.17.1
index 45e34dfbe7a8bcd9f860735e1b1f235657e4c708..e83d947c85d8148d71178a001d757e9b74b741c1 100755 (executable)
@@ -32,6 +32,9 @@ if [ "${PDNS_DEBUG}" = "YES" ]; then
   set -x
 fi
 
+rm -f ca.key ca.pem ca.srl server.csr server.key server.pem server.chain
+rm -rf configs/*
+
 # Generate a new CA
 openssl req -new -x509 -days 1 -extensions v3_ca -keyout ca.key -out ca.pem -nodes -config configCA.conf
 # Generate a new server certificate request
@@ -39,7 +42,7 @@ openssl req -new -newkey rsa:2048 -nodes -keyout server.key -out server.csr -con
 # Sign the server cert
 openssl x509 -req -days 1 -CA ca.pem -CAkey ca.key -CAcreateserial -in server.csr -out server.pem
 # Generate a chain
-cat server.pem ca.pem >> server.chain
+cat server.pem ca.pem > server.chain
 
 if ! nosetests --with-xunit $@; then
     for log in configs/*.log; do
index 0b358ca7e64bc9ba41229156c4bf1c6f614803e0..01bc901baea5eacfff5f366b4bea6a74ef22cdb2 100644 (file)
@@ -1 +1 @@
-addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
index 00578a3e779b0fefcad922c0b99a2fcbbceb3cfb..8ec87804a7f352e00a4a611830e668d210db6149 100644 (file)
@@ -114,7 +114,7 @@ class TestAPIBasics(DNSDistTest):
             self.assertTrue(server['state'] in ['up', 'down', 'UP', 'DOWN'])
 
         for frontend in content['frontends']:
-            for key in ['id', 'address', 'udp', 'tcp', 'queries']:
+            for key in ['id', 'address', 'udp', 'tcp', 'type', 'queries']:
                 self.assertIn(key, frontend)
 
             for key in ['id', 'queries']:
@@ -226,6 +226,7 @@ class TestAPIBasics(DNSDistTest):
             values[entry['name']] = entry['value']
 
         expected = ['responses', 'servfail-responses', 'queries', 'acl-drops',
+                    'frontend-noerror', 'frontend-nxdomain', 'frontend-servfail',
                     'rule-drop', 'rule-nxdomain', 'rule-refused', 'self-answered', 'downstream-timeouts',
                     'downstream-send-errors', 'trunc-failures', 'no-policy', 'latency0-1',
                     'latency1-10', 'latency10-50', 'latency50-100', 'latency100-1000',
@@ -255,6 +256,7 @@ class TestAPIBasics(DNSDistTest):
         content = r.json()
 
         expected = ['responses', 'servfail-responses', 'queries', 'acl-drops',
+                    'frontend-noerror', 'frontend-nxdomain', 'frontend-servfail',
                     'rule-drop', 'rule-nxdomain', 'rule-refused', 'self-answered', 'downstream-timeouts',
                     'downstream-send-errors', 'trunc-failures', 'no-policy', 'latency0-1',
                     'latency1-10', 'latency10-50', 'latency50-100', 'latency100-1000',
@@ -358,24 +360,36 @@ class TestAPIWritable(DNSDistTest):
         self.assertEquals(r.status_code, 200)
         self.assertTrue(r.json())
         content = r.json()
-        self.assertEquals(content['value'], newACL)
+        acl = content['value']
+        acl.sort()
+        self.assertEquals(acl, newACL)
 
         r = requests.get(url, headers=headers, timeout=self._webTimeout)
         self.assertTrue(r)
         self.assertEquals(r.status_code, 200)
         self.assertTrue(r.json())
         content = r.json()
-        self.assertEquals(content['value'], newACL)
+        acl = content['value']
+        acl.sort()
+        self.assertEquals(acl, newACL)
 
         configFile = self._APIWriteDir + '/' + 'acl.conf'
         self.assertTrue(os.path.isfile(configFile))
         fileContent = None
         with open(configFile, 'rt') as f:
-            fileContent = f.read()
-
-        self.assertEquals(fileContent, """-- Generated by the REST API, DO NOT EDIT
-setACL({"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"})
-""")
+            header = f.readline()
+            body = f.readline()
+
+        self.assertEquals(header, """-- Generated by the REST API, DO NOT EDIT\n""")
+
+        self.assertIn(body, {
+            """setACL({"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"})\n""",
+            """setACL({"192.0.2.0/24", "203.0.113.0/24", "198.51.100.0/24"})\n""",
+            """setACL({"198.51.100.0/24", "192.0.2.0/24", "203.0.113.0/24"})\n""",
+            """setACL({"198.51.100.0/24", "203.0.113.0/24", "192.0.2.0/24"})\n""",
+            """setACL({"203.0.113.0/24", "192.0.2.0/24", "198.51.100.0/24"})\n""",
+            """setACL({"203.0.113.0/24", "198.51.100.0/24", "192.0.2.0/24"})\n"""
+        })
 
 class TestAPICustomHeaders(DNSDistTest):
 
index 4bb4c543f7fe43f9b92fe9ac08189d1e227609ce..6bc5004614ad49caf2460b98518f2ddc614063b0 100644 (file)
@@ -165,111 +165,111 @@ class TestAXFR(DNSDistTest):
         self.assertEqual(query, receivedQuery)
         self.assertEqual(len(receivedResponses), len(responses))
 
-    def testFourNoFirstSOAAXFR(self):
-        """
-        AXFR: Four messages, no SOA in the first one
-        """
-        name = 'fournosoainfirst.axfr.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AXFR', 'IN')
-        responses = []
-        soa = dns.rrset.from_text(name,
-                                  60,
-                                  dns.rdataclass.IN,
-                                  dns.rdatatype.SOA,
-                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
-        response = dns.message.make_response(query)
-        response.answer.append(dns.rrset.from_text(name,
-                                                   60,
-                                                   dns.rdataclass.IN,
-                                                   dns.rdatatype.A,
-                                                   '192.0.2.1'))
-        responses.append(response)
+    def testFourNoFirstSOAAXFR(self):
+        """
+        AXFR: Four messages, no SOA in the first one
+        """
+        name = 'fournosoainfirst.axfr.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AXFR', 'IN')
+        responses = []
+        soa = dns.rrset.from_text(name,
+                                  60,
+                                  dns.rdataclass.IN,
+                                  dns.rdatatype.SOA,
+                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
+        response = dns.message.make_response(query)
+        response.answer.append(dns.rrset.from_text(name,
+                                                   60,
+                                                   dns.rdataclass.IN,
+                                                   dns.rdatatype.A,
+                                                   '192.0.2.1'))
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(soa)
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(soa)
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text('dummy.' + name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text('dummy.' + name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.TXT,
-                                    'dummy')
-        response.answer.append(rrset)
-        response.answer.append(soa)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.TXT,
+                                    'dummy')
+        response.answer.append(rrset)
+        response.answer.append(soa)
+        responses.append(response)
 
-        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
-        receivedQuery.id = query.id
-        self.assertEqual(query, receivedQuery)
-        self.assertEqual(len(receivedResponses), 1)
+        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(len(receivedResponses), 1)
 
-    def testFourLastSOAInSecondAXFR(self):
-        """
-        AXFR: Four messages, SOA in the first one and the second one
-        """
-        name = 'foursecondsoainsecond.axfr.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AXFR', 'IN')
-        responses = []
-        soa = dns.rrset.from_text(name,
-                                  60,
-                                  dns.rdataclass.IN,
-                                  dns.rdatatype.SOA,
-                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
+    def testFourLastSOAInSecondAXFR(self):
+        """
+        AXFR: Four messages, SOA in the first one and the second one
+        """
+        name = 'foursecondsoainsecond.axfr.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AXFR', 'IN')
+        responses = []
+        soa = dns.rrset.from_text(name,
+                                  60,
+                                  dns.rdataclass.IN,
+                                  dns.rdatatype.SOA,
+                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
 
-        response = dns.message.make_response(query)
-        response.answer.append(soa)
-        response.answer.append(dns.rrset.from_text(name,
-                                                   60,
-                                                   dns.rdataclass.IN,
-                                                   dns.rdatatype.A,
-                                                   '192.0.2.1'))
-        responses.append(response)
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        response.answer.append(dns.rrset.from_text(name,
+                                                   60,
+                                                   dns.rdataclass.IN,
+                                                   dns.rdatatype.A,
+                                                   '192.0.2.1'))
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        response.answer.append(soa)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text('dummy.' + name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text('dummy.' + name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.TXT,
-                                    'dummy')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.TXT,
+                                    'dummy')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
-        receivedQuery.id = query.id
-        self.assertEqual(query, receivedQuery)
-        self.assertEqual(len(receivedResponses), 2)
+        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(len(receivedResponses), 2)
index ea943bb00cef825da4181b2c639a560202807037..4a4c7e55535a62170cc1429b3518eed36561fd46 100644 (file)
@@ -34,19 +34,14 @@ class TestAdvancedAllow(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedAllowDropped(self):
         """
@@ -56,11 +51,10 @@ class TestAdvancedAllow(DNSDistTest):
         """
         name = 'notallowed.advanced.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
 
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
 
 class TestAdvancedFixupCase(DNSDistTest):
 
@@ -91,20 +85,14 @@ class TestAdvancedFixupCase(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestAdvancedRemoveRD(DNSDistTest):
 
@@ -133,19 +121,14 @@ class TestAdvancedRemoveRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepRD(self):
         """
@@ -165,20 +148,14 @@ class TestAdvancedRemoveRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedAddCD(DNSDistTest):
 
@@ -208,19 +185,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedSetCDViaAction(self):
         """
@@ -242,19 +214,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepNoCD(self):
         """
@@ -274,19 +241,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedClearRD(DNSDistTest):
 
@@ -316,19 +278,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedClearRDViaAction(self):
         """
@@ -350,19 +307,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepRD(self):
         """
@@ -382,19 +334,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 
 class TestAdvancedACL(DNSDistTest):
@@ -415,11 +362,10 @@ class TestAdvancedACL(DNSDistTest):
         name = 'tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestAdvancedDelay(DNSDistTest):
 
@@ -506,7 +452,7 @@ class TestAdvancedTruncateAnyAndTCP(DNSDistTest):
 class TestAdvancedAndNot(DNSDistTest):
 
     _config_template = """
-    addAction(AndRule({NotRule(QTypeRule("A")), TCPRule(false)}), RCodeAction(dnsdist.NOTIMP))
+    addAction(AndRule({NotRule(QTypeRule("A")), TCPRule(false)}), RCodeAction(DNSRCode.NOTIMP))
     newServer{address="127.0.0.1:%s"}
     """
     def testAOverUDPReturnsNotImplementedCanary(self):
@@ -528,19 +474,14 @@ class TestAdvancedAndNot(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testAOverUDPReturnsNotImplemented(self):
         """
@@ -578,7 +519,7 @@ class TestAdvancedAndNot(DNSDistTest):
 class TestAdvancedOr(DNSDistTest):
 
     _config_template = """
-    addAction(OrRule({QTypeRule("A"), TCPRule(false)}), RCodeAction(dnsdist.NOTIMP))
+    addAction(OrRule({QTypeRule("A"), TCPRule(false)}), RCodeAction(DNSRCode.NOTIMP))
     newServer{address="127.0.0.1:%s"}
     """
     def testAAAAOverUDPReturnsNotImplemented(self):
@@ -628,11 +569,10 @@ class TestAdvancedOr(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 
 class TestAdvancedLogAction(DNSDistTest):
@@ -690,24 +630,19 @@ class TestAdvancedDNSSEC(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
-        (_, receivedResponse) = self.sendUDPQuery(doquery, response)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(doquery, response)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(doquery, response)
+            self.assertEquals(receivedResponse, None)
 
 class TestAdvancedQClass(DNSDistTest):
 
@@ -723,10 +658,10 @@ class TestAdvancedQClass(DNSDistTest):
         name = 'qclasschaos.advanced.tests.powerdns.com.'
         query = dns.message.make_query(name, 'TXT', 'CHAOS')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testAdvancedQClassINAllow(self):
         """
@@ -743,19 +678,14 @@ class TestAdvancedQClass(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedOpcode(DNSDistTest):
 
@@ -772,10 +702,10 @@ class TestAdvancedOpcode(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         query.set_opcode(dns.opcode.NOTIFY)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testAdvancedOpcodeUpdateINAllow(self):
         """
@@ -793,19 +723,14 @@ class TestAdvancedOpcode(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedNonTerminalRule(DNSDistTest):
 
@@ -835,19 +760,14 @@ class TestAdvancedNonTerminalRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedStringOnlyServer(DNSDistTest):
 
@@ -869,19 +789,14 @@ class TestAdvancedStringOnlyServer(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedRestoreFlagsOnSelfResponse(DNSDistTest):
 
@@ -912,13 +827,11 @@ class TestAdvancedRestoreFlagsOnSelfResponse(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(response, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedQPS(DNSDistTest):
 
@@ -972,7 +885,7 @@ class TestAdvancedQPSNone(DNSDistTest):
 
     _config_template = """
     addAction("qpsnone.advanced.tests.powerdns.com", QPSAction(100))
-    addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -989,18 +902,17 @@ class TestAdvancedQPSNone(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedNMGRule(DNSDistTest):
 
     _config_template = """
     allowed = newNMG()
     allowed:addMask("192.0.2.1/32")
-    addAction(NotRule(NetmaskGroupRule(allowed)), RCodeAction(dnsdist.REFUSED))
+    addAction(NotRule(NetmaskGroupRule(allowed)), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -1016,17 +928,16 @@ class TestAdvancedNMGRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestDSTPortRule(DNSDistTest):
 
     _config_params = ['_dnsDistPort', '_testServerPort']
     _config_template = """
-    addAction(DSTPortRule(%d), RCodeAction(dnsdist.REFUSED))
+    addAction(DSTPortRule(%d), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -1043,16 +954,15 @@ class TestDSTPortRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedLabelsCountRule(DNSDistTest):
 
     _config_template = """
-    addAction(QNameLabelsCountRule(5,6), RCodeAction(dnsdist.REFUSED))
+    addAction(QNameLabelsCountRule(5,6), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -1071,19 +981,14 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # more than 6 labels, the query should be refused
         name = 'not.ok.labelscount.advanced.tests.powerdns.com.'
@@ -1091,11 +996,10 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
         # less than 5 labels, the query should be refused
         name = 'labelscountadvanced.tests.powerdns.com.'
@@ -1103,16 +1007,15 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedWireLengthRule(DNSDistTest):
 
     _config_template = """
-    addAction(QNameWireLengthRule(54,56), RCodeAction(dnsdist.REFUSED))
+    addAction(QNameWireLengthRule(54,56), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -1130,19 +1033,14 @@ class TestAdvancedWireLengthRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # too short, the query should be refused
         name = 'short.qnamewirelength.advanced.tests.powerdns.com.'
@@ -1150,11 +1048,10 @@ class TestAdvancedWireLengthRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
         # too long, the query should be refused
         name = 'toolongtobevalid.qnamewirelength.advanced.tests.powerdns.com.'
@@ -1162,11 +1059,10 @@ class TestAdvancedWireLengthRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedIncludeDir(DNSDistTest):
 
@@ -1190,19 +1086,14 @@ class TestAdvancedIncludeDir(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # this one should be refused
         name = 'notincludedir.advanced.tests.powerdns.com.'
@@ -1210,11 +1101,10 @@ class TestAdvancedIncludeDir(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedLuaDO(DNSDistTest):
 
@@ -1247,30 +1137,22 @@ class TestAdvancedLuaDO(DNSDistTest):
         doResponse.set_rcode(dns.rcode.NXDOMAIN)
 
         # without DO
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # with DO
-        (_, receivedResponse) = self.sendUDPQuery(queryWithDO, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        doResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, doResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(queryWithDO, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        doResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, doResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(queryWithDO, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            doResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, doResponse)
 
 class TestAdvancedLuaRefused(DNSDistTest):
 
@@ -1298,15 +1180,12 @@ class TestAdvancedLuaRefused(DNSDistTest):
         refusedResponse = dns.message.make_response(query)
         refusedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            refusedResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, refusedResponse)
 
 class TestAdvancedLuaActionReturnSyntax(DNSDistTest):
 
@@ -1334,15 +1213,12 @@ class TestAdvancedLuaActionReturnSyntax(DNSDistTest):
         refusedResponse = dns.message.make_response(query)
         refusedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            refusedResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, refusedResponse)
 
 class TestAdvancedLuaTruncated(DNSDistTest):
 
@@ -1475,7 +1351,7 @@ com.""")
 class TestAdvancedRD(DNSDistTest):
 
     _config_template = """
-    addAction(RDRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(RDRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -1488,11 +1364,10 @@ class TestAdvancedRD(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAdvancedNoRDAllowed(self):
         """
@@ -1503,15 +1378,12 @@ class TestAdvancedRD(DNSDistTest):
         query.flags &= ~dns.flags.RD
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalPort(DNSDistTest):
 
@@ -1541,11 +1413,10 @@ class TestAdvancedGetLocalPort(DNSDistTest):
                                     'port-was-{}.local-port.advanced.tests.powerdns.com.'.format(self._dnsDistPort))
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalPortOnAnyBind(DNSDistTest):
 
@@ -1576,11 +1447,10 @@ class TestAdvancedGetLocalPortOnAnyBind(DNSDistTest):
                                     'port-was-{}.local-port-any.advanced.tests.powerdns.com.'.format(self._dnsDistPort))
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
 
@@ -1614,11 +1484,10 @@ class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
                                     'address-was-127-0-0-1.local-address-any.advanced.tests.powerdns.com.')
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedLuaTempFailureTTL(DNSDistTest):
 
@@ -1655,12 +1524,14 @@ class TestAdvancedLuaTempFailureTTL(DNSDistTest):
                                     '::1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedEDNSOptionRule(DNSDistTest):
 
@@ -1679,10 +1550,10 @@ class TestAdvancedEDNSOptionRule(DNSDistTest):
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testReplied(self):
         """
@@ -1695,25 +1566,29 @@ class TestAdvancedEDNSOptionRule(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[], payload=512)
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
 
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # and with no EDNS at all
         query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
 
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedAllowHeaderOnly(DNSDistTest):
 
@@ -1775,11 +1650,11 @@ class TestAdvancedAllowHeaderOnly(DNSDistTest):
             self.assertEquals(query, receivedQuery)
             self.assertEquals(receivedResponse, response)
 
-class TestAdvancedEDNSVersionnRule(DNSDistTest):
+class TestAdvancedEDNSVersionRule(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addAction(EDNSVersionRule(0), ERCodeAction(dnsdist.BADVERS))
+    addAction(EDNSVersionRule(0), ERCodeAction(DNSRCode.BADVERS))
     """
 
     def testDropped(self):
index ea4b7b08b9a94d81d4e828cd60fd83fd3c0915cc..5dfda57d21467c8a49411f3564aeb5cc23c87b37 100644 (file)
@@ -9,15 +9,15 @@ class TestBasics(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     truncateTC(true)
-    addAction(AndRule{QTypeRule(dnsdist.ANY), TCPRule(false)}, TCAction())
-    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule{QTypeRule(DNSQType.ANY), TCPRule(false)}, TCAction())
+    addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(DNSRCode.REFUSED))
     mySMN = newSuffixMatchNode()
     mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com."))
-    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(dnsdist.NOTIMP))
+    addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(DNSRCode.NOTIMP))
     addAction(makeRule("drop.test.powerdns.com."), DropAction())
-    addAction(AndRule({QTypeRule(dnsdist.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
-    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(dnsdist.REFUSED))
-    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(dnsdist.REFUSED))
+    addAction(AndRule({QTypeRule(DNSQType.A),QNameRule("ds9a.nl")}), SpoofAction("1.2.3.4"))
+    addAction(newDNSName("dnsname.addaction.powerdns.com."), RCodeAction(DNSRCode.REFUSED))
+    addAction({newDNSName("dnsname-table1.addaction.powerdns.com."), newDNSName("dnsname-table2.addaction.powerdns.com.")}, RCodeAction(DNSRCode.REFUSED))
     """
 
     def testDropped(self):
@@ -30,11 +30,10 @@ class TestBasics(DNSDistTest):
         """
         name = 'drop.test.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
     def testAWithECS(self):
         """
@@ -52,15 +51,12 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testSimpleA(self):
         """
@@ -76,19 +72,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAnyIsTruncated(self):
         """
@@ -214,11 +205,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testQNameReturnsSpoofed(self):
         """
@@ -238,12 +228,10 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.4')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainAndQTypeReturnsNotImplemented(self):
         """
@@ -259,11 +247,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainWithoutQTypeIsNotAffected(self):
         """
@@ -284,19 +271,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testOtherDomainANDQTypeIsNotAffected(self):
         """
@@ -317,19 +299,14 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testWrongResponse(self):
         """
@@ -353,17 +330,13 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         unrelatedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, unrelatedResponse)
+            self.assertTrue(receivedQuery)
+            self.assertEquals(receivedResponse, None)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
 
     def testHeaderOnlyRefused(self):
         """
@@ -375,17 +348,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.REFUSED)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testHeaderOnlyNoErrorResponse(self):
         """
@@ -396,17 +365,13 @@ class TestBasics(DNSDistTest):
         response = dns.message.make_response(query)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testHeaderOnlyNXDResponse(self):
         """
@@ -418,17 +383,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.NXDOMAIN)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testAddActionDNSName(self):
         """
@@ -439,8 +400,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAddActionDNSNames(self):
         """
@@ -451,9 +414,8 @@ class TestBasics(DNSDistTest):
             expectedResponse = dns.message.make_response(query)
             expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-            self.assertEquals(receivedResponse, expectedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (_, receivedResponse) = sender(query, response=None, useQueue=False)
+                self.assertEquals(receivedResponse, expectedResponse)
 
-if __name__ == '__main__':
-    unittest.main()
-    exit(0)
index a6b8aafa6dcec8f2c6962ede39f0f5b47354e0d9..589372b3e986245804a573a97d3c493ad990011f 100644 (file)
@@ -7,7 +7,7 @@ from dnsdisttests import DNSDistTest
 class TestCacheHitResponses(DNSDistTest):
 
     _config_template = """
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     addCacheHitResponseAction(makeRule("dropwhencached.cachehitresponses.tests.powerdns.com."), DropResponseAction())
     newServer{address="127.0.0.1:%s"}
index d50aa88694c2316745a5e69651e25974114d6404..aba9d554bd98397f14bf26e49addec30422ecbb0 100644 (file)
@@ -8,8 +8,7 @@ from dnsdisttests import DNSDistTest
 class TestCaching(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     addAction(makeRule("nocache.cache.tests.powerdns.com."), SkipCacheAction())
     function skipViaLua(dq)
@@ -155,19 +154,14 @@ class TestCaching(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
-
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                self.assertTrue(receivedQuery)
+                self.assertTrue(receivedResponse)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(receivedResponse, response)
 
         for key in self._responsesCounter:
             value = self._responsesCounter[key]
@@ -192,19 +186,14 @@ class TestCaching(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
-
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                self.assertTrue(receivedQuery)
+                self.assertTrue(receivedResponse)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(receivedResponse, response)
 
         for key in self._responsesCounter:
             value = self._responsesCounter[key]
@@ -417,8 +406,7 @@ class TestCaching(DNSDistTest):
 class TestTempFailureCacheTTLAction(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     addAction("servfail.cache.tests.powerdns.com.", TempFailureCacheTTLAction(1))
     newServer{address="127.0.0.1:%d"}
@@ -464,8 +452,7 @@ class TestTempFailureCacheTTLAction(DNSDistTest):
 class TestCachingWithExistingEDNS(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(5, 86400, 1)
+    pc = newPacketCache(5, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -522,8 +509,7 @@ class TestCachingWithExistingEDNS(DNSDistTest):
 class TestCachingCacheFull(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(1, 86400, 1)
+    pc = newPacketCache(1, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -597,8 +583,7 @@ class TestCachingNoStale(DNSDistTest):
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     setKey("%s")
     controlSocket("127.0.0.1:%d")
@@ -609,7 +594,7 @@ class TestCachingNoStale(DNSDistTest):
         Cache: Cache entry, set backend down, we should not get a stale entry
 
         """
-        ttl = 1
+        ttl = 2
         name = 'nostale.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -649,8 +634,7 @@ class TestCachingStale(DNSDistTest):
     _staleCacheTTL = 60
     _config_params = ['_staleCacheTTL', '_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=XX
-    pc = newPacketCache(100, 86400, 1, 0, %d)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=%d})
     getPool(""):setCache(pc)
     setStaleCacheEntriesTTL(600)
     setKey("%s")
@@ -663,7 +647,7 @@ class TestCachingStale(DNSDistTest):
 
         """
         misses = 0
-        ttl = 1
+        ttl = 2
         name = 'stale.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -711,8 +695,7 @@ class TestCachingStaleExpunged(DNSDistTest):
     _staleCacheTTL = 60
     _config_params = ['_staleCacheTTL', '_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=XX
-    pc = newPacketCache(100, 86400, 1, 0, %d)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=%d})
     getPool(""):setCache(pc)
     setStaleCacheEntriesTTL(600)
     -- try to remove all expired entries
@@ -729,7 +712,7 @@ class TestCachingStaleExpunged(DNSDistTest):
         """
         misses = 0
         drops = 0
-        ttl = 1
+        ttl = 2
         name = 'stale-but-expunged.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -786,8 +769,7 @@ class TestCachingStaleExpungePrevented(DNSDistTest):
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, ecsParsing=false, keepStaleData=true
-    pc = newPacketCache(100, 86400, 1, 0, 60, false, 1, true, 3600, false, { keepStaleData=true})
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=0, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, ecsParsing=false, keepStaleData=true})
     getPool(""):setCache(pc)
     setStaleCacheEntriesTTL(600)
     -- try to remove all expired entries
@@ -803,7 +785,7 @@ class TestCachingStaleExpungePrevented(DNSDistTest):
         Cache: Cache entry, set backend down, wait for the cache cleaning to run and remove the entry, still get a cache HIT because the stale entry was not removed
         """
         misses = 0
-        ttl = 1
+        ttl = 2
         name = 'stale-not-expunged.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -860,8 +842,7 @@ class TestCacheManagement(DNSDistTest):
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     setKey("%s")
     controlSocket("127.0.0.1:%d")
@@ -1057,7 +1038,7 @@ class TestCacheManagement(DNSDistTest):
         self.assertEquals(receivedResponse, response2)
 
         # remove cached entries from name A
-        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"" + name + "\"), dnsdist.A)")
+        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"" + name + "\"), DNSQType.A)")
 
         # Miss for name A
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
@@ -1138,7 +1119,7 @@ class TestCacheManagement(DNSDistTest):
         self.assertEquals(receivedResponse, response2)
 
         # remove cached entries from name
-        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"suffix.cache.tests.powerdns.com.\"), dnsdist.ANY, true)")
+        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"suffix.cache.tests.powerdns.com.\"), DNSQType.ANY, true)")
 
         # Miss for name
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
@@ -1219,7 +1200,7 @@ class TestCacheManagement(DNSDistTest):
         self.assertEquals(receivedResponse, response2)
 
         # remove cached entries from name A
-        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"suffixtype.cache.tests.powerdns.com.\"), dnsdist.A, true)")
+        self.sendConsoleCommand("getPool(\"\"):getCache():expungeByName(newDNSName(\"suffixtype.cache.tests.powerdns.com.\"), DNSQType.A, true)")
 
         # Miss for name A
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
@@ -1249,8 +1230,7 @@ class TestCachingTTL(DNSDistTest):
     _minCacheTTL = 600
     _config_params = ['_maxCacheTTL', '_minCacheTTL', '_testServerPort']
     _config_template = """
-    -- maxTTL=XX, minTTL=XX
-    pc = newPacketCache(1000, %d, %d)
+    pc = newPacketCache(1000, {maxTTL=%d, minTTL=%d})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1339,8 +1319,7 @@ class TestCachingLongTTL(DNSDistTest):
     _maxCacheTTL = 2
     _config_params = ['_maxCacheTTL', '_testServerPort']
     _config_template = """
-    -- maxTTL=XX
-    pc = newPacketCache(1000, %d)
+    pc = newPacketCache(1000, {maxTTL=%d})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1403,8 +1382,7 @@ class TestCachingFailureTTL(DNSDistTest):
     _failureCacheTTL = 2
     _config_params = ['_failureCacheTTL', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=0, temporaryFailureTTL=XX, staleTTL=60
-    pc = newPacketCache(1000, 86400, 0, %d, 60)
+    pc = newPacketCache(1000, {maxTTL=86400, minTTL=0, temporaryFailureTTL=%d, staleTTL=60})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1540,8 +1518,7 @@ class TestCachingNegativeTTL(DNSDistTest):
     _negCacheTTL = 1
     _config_params = ['_negCacheTTL', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=XX
-    pc = newPacketCache(1000, 86400, 0, 60, 60, false, 1, true, %d)
+    pc = newPacketCache(1000, {maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=%d})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1578,7 +1555,7 @@ class TestCachingNegativeTTL(DNSDistTest):
 
         time.sleep(self._negCacheTTL + 1)
 
-        # we should not have cached for longer than the negativel TTL
+        # we should not have cached for longer than the negative TTL
         # so it should be a miss
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
@@ -1645,8 +1622,7 @@ class TestCachingNegativeTTL(DNSDistTest):
 class TestCachingDontAge(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60, dontAge=true
-    pc = newPacketCache(100, 86400, 0, 60, 60, true)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=0, temporaryFailureTTL=60, staleTTL=60, dontAge=true})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1706,8 +1682,7 @@ class TestCachingECSWithoutPoolECS(DNSDistTest):
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     setKey("%s")
     controlSocket("127.0.0.1:%d")
@@ -1730,38 +1705,30 @@ class TestCachingECSWithoutPoolECS(DNSDistTest):
         response.answer.append(rrset)
 
         # first query to fill the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # next queries should hit the cache
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # over TCP too
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
         # we mark the backend as down
         self.sendConsoleCommand("getServer(0):setDown()")
 
         # we should NOT get a cached entry since it has ECS and we haven't asked the pool
         # to add ECS when no backend is up
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        # same over TCP
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestCachingECSWithPoolECS(DNSDistTest):
 
@@ -1769,8 +1736,7 @@ class TestCachingECSWithPoolECS(DNSDistTest):
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort']
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     getPool(""):setECS(true)
     setKey("%s")
@@ -1794,44 +1760,35 @@ class TestCachingECSWithPoolECS(DNSDistTest):
         response.answer.append(rrset)
 
         # first query to fill the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # next queries should hit the cache
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # over TCP too
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
         # we mark the backend as down
         self.sendConsoleCommand("getServer(0):setDown()")
 
         # we should STILL get a cached entry since it has ECS and we have asked the pool
         # to add ECS when no backend is up
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # same over TCP
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestCachingCollisionNoECSParsing(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1873,8 +1830,7 @@ class TestCachingCollisionNoECSParsing(DNSDistTest):
 class TestCachingCollisionWithECSParsing(DNSDistTest):
 
     _config_template = """
-    -- maxTTL=86400, minTTL=1, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, parseECS=true
-    pc = newPacketCache(100, 86400, 1, 60, 60, false, 1, true, 3600, true)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, parseECS=true})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d"}
     """
@@ -1923,7 +1879,7 @@ class TestCachingScopeZero(DNSDistTest):
 
     _config_template = """
     -- Be careful to enable ECS parsing in the packet cache, otherwise scope zero is disabled
-    pc = newPacketCache(100, 86400, 1, 60, 60, false, 1, true, 3600, true)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, parseECS=true})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d", useClientSubnet=true}
     -- to simulate a second client coming from a different IP address,
@@ -2111,7 +2067,7 @@ class TestCachingScopeZeroButNoSubnetcheck(DNSDistTest):
 
     _config_template = """
     -- We disable ECS parsing in the packet cache, meaning scope zero is disabled
-    pc = newPacketCache(100, 86400, 1, 60, 60, false, 1, true, 3600, false)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1, temporaryFailureTTL=60, staleTTL=60, dontAge=false, numberOfShards=1, deferrableInsertLock=true, maxNegativeTTL=3600, parseECS=false})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%d", useClientSubnet=true}
     -- to simulate a second client coming from a different IP address,
index 1971ddd9e1dd6e4015fa10d26bdd310c0fff1229..f075f4859e2dfc7f8c0524d69bf78a74e974a2db 100644 (file)
@@ -137,7 +137,7 @@ class TestCarbon(DNSDistTest):
         for line in data1.splitlines():
             if expectedStart in line:
                 parts = line.split(b' ')
-                if 'servers-up' in line:
+                if b'servers-up' in line:
                     self.assertEquals(len(parts), 3)
                     self.assertTrue(parts[1].isdigit())
                     self.assertEquals(int(parts[1]), 2)
@@ -160,7 +160,7 @@ class TestCarbon(DNSDistTest):
         for line in data2.splitlines():
             if expectedStart in line:
                 parts = line.split(b' ')
-                if 'servers-up' in line:
+                if b'servers-up' in line:
                     self.assertEquals(len(parts), 3)
                     self.assertTrue(parts[1].isdigit())
                     self.assertEquals(int(parts[1]), 2)
index ba3befb185721e858b7e9066075434d294849541..ce9345c05ed68d684a370432a925b3033617c0a1 100644 (file)
@@ -33,11 +33,11 @@ class TestCheckConfig(unittest.TestCase):
         configTemplate = """
             newServer{address="127.0.0.1:53"}
             truncateTC(true)
-            addAction(AndRule{QTypeRule(dnsdist.ANY), TCPRule(false)}, TCAction())
-            addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(dnsdist.REFUSED))
+            addAction(AndRule{QTypeRule(DNSQType.ANY), TCPRule(false)}, TCAction())
+            addAction(RegexRule("evil[0-9]{4,}\\\\.regex\\\\.tests\\\\.powerdns\\\\.com$"), RCodeAction(DNSRCode.REFUSED))
             mySMN = newSuffixMatchNode()
             mySMN:add(newDNSName("nameAndQtype.tests.powerdns.com."))
-            addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(dnsdist.NOTIMP))
+            addAction(AndRule{SuffixMatchNodeRule(mySMN), QTypeRule("TXT")}, RCodeAction(DNSRCode.NOTIMP))
             addAction(makeRule("drop.test.powerdns.com."), DropAction())
         """
 
index 08a021029fa3f654c488ff11d1d1ad6ac86e3c41..c1d59a4d5c812e5b4823ea78ca69887ae63dcdd0 100644 (file)
@@ -235,7 +235,7 @@ class TestDNSCryptWithCache(DNSCryptTest):
     _config_template = """
     generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
     addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
-    pc = newPacketCache(5, 86400, 1)
+    pc = newPacketCache(5, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
     newServer{address="127.0.0.1:%s"}
     """
index 78e71bffc89556d39f3b0e66e7188567929e1642..f91647e7b7c5ec0404422a5613a0a1d18c013d55 100644 (file)
@@ -102,7 +102,7 @@ class TestDnstapOverRemoteLogger(DNSDistTest):
 
     function luaFunc(dq)
       dq.dh:setQR(true)
-      dq.dh:setRCode(dnsdist.NXDOMAIN)
+      dq.dh:setRCode(DNSRCode.NXDOMAIN)
       return DNSAction.None, ""
     end
 
index 4a4259420decf6b2419cc3fe9b8c69e50b4dc60b..06260357451cb295abb7aa0b3843486b656b184d 100644 (file)
@@ -4,7 +4,11 @@ import json
 import requests
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+try:
+  range = xrange
+except NameError:
+  pass
 
 class DynBlocksTest(DNSDistTest):
 
@@ -799,7 +803,7 @@ class TestDynBlockGroupServFails(DynBlocksTest):
     _config_params = ['_dynBlockQPS', '_dynBlockPeriod', '_dynBlockDuration', '_testServerPort']
     _config_template = """
     local dbr = dynBlockRulesGroup()
-    dbr:setRCodeRate(dnsdist.SERVFAIL, %d, %d, "Exceeded query rate", %d)
+    dbr:setRCodeRate(DNSRCode.SERVFAIL, %d, %d, "Exceeded query rate", %d)
 
     function maintenance()
            dbr:apply()
index f5838bf40ce1efc50a58ff6352dcd06f7c76b372..a655fd5ba7f227c2d5271d96764c98a9ee875c2f 100644 (file)
@@ -23,17 +23,17 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.COOKIE]:count() ~= 2 then
           return DNSAction.Spoof, "192.0.2.2"
         end
-        if options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.3"
         end
-        if options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:getValues()[2]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.4"
         end
       elseif string.match(qname, 'cookie') then
         if options[EDNSOptionCode.COOKIE] == nil then
           return DNSAction.Spoof, "192.0.2.1"
         end
-        if options[EDNSOptionCode.COOKIE]:count() ~= 1 or options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:count() ~= 1 or options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.2"
         end
       end
@@ -42,7 +42,7 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.ECS] == nil then
           return DNSAction.Spoof, "192.0.2.51"
         end
-        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 8 then
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[1]:len() ~= 8 then
           return DNSAction.Spoof, "192.0.2.52"
         end
       end
@@ -51,7 +51,7 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.ECS] == nil then
           return DNSAction.Spoof, "192.0.2.101"
         end
-        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 20 then
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[1]:len() ~= 20 then
           return DNSAction.Spoof, "192.0.2.102"
         end
       end
@@ -66,7 +66,7 @@ class TestEDNSOptions(EDNSOptionsBase):
     _config_template = """
     %s
 
-    addLuaAction(AllRule(), testEDNSOptions)
+    addAction(AllRule(), LuaAction(testEDNSOptions))
 
     newServer{address="127.0.0.1:%s"}
     """
@@ -86,26 +86,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '192.0.2.255')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testCookie(self):
         """
         EDNS Options: Cookie
         """
         name = 'cookie.ednsoptions.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -115,19 +110,14 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS4(self):
         """
@@ -144,19 +134,14 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS6(self):
         """
@@ -173,26 +158,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS6Cookie(self):
         """
         EDNS Options: Cookie + ECS6
         """
         name = 'cookie-ecs6.ednsoptions.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
         response = dns.message.make_response(query)
@@ -203,28 +183,23 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testMultiCookiesECS6(self):
         """
         EDNS Options: Two Cookies + ECS6
         """
         name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
-        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
-        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -234,26 +209,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
 class TestEDNSOptionsAddingECS(EDNSOptionsBase):
 
     _config_template = """
     %s
 
-    addLuaAction(AllRule(), testEDNSOptions)
+    addAction(AllRule(), LuaAction(testEDNSOptions))
 
     newServer{address="127.0.0.1:%s", useClientSubnet=true}
     """
@@ -275,26 +245,21 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
     def testCookie(self):
         """
         EDNS Options: Cookie (adding ECS)
         """
         name = 'cookie.ednsoptions-ecs.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[eco,ecso], payload=512)
@@ -306,19 +271,14 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+            self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
     def testECS4(self):
         """
@@ -337,19 +297,14 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testECS6(self):
         """
@@ -368,26 +323,21 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testECS6Cookie(self):
         """
         EDNS Options: Cookie + ECS6 (adding ECS)
         """
         name = 'cookie-ecs6.ednsoptions-ecs.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
         ecsoResponse = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128, scope=56)
@@ -400,28 +350,23 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery, 1)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testMultiCookiesECS6(self):
         """
         EDNS Options: Two Cookies + ECS6
         """
         name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
-        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
-        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -431,16 +376,11 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
index 7704641319a7e08531bb775f5896d105b50dd05b..c5cb99bd7bd8a115d3d3a99dc3b25927483fe031 100644 (file)
@@ -11,14 +11,14 @@ class TestEDNSSelfGenerated(DNSDistTest):
     """
 
     _config_template = """
-    addAction("rcode.edns-self.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
+    addAction("rcode.edns-self.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED))
     addAction("tc.edns-self.tests.powerdns.com.", TCAction())
 
     function luarule(dq)
       return DNSAction.Nxdomain, ""
     end
 
-    addLuaAction("lua.edns-self.tests.powerdns.com.", luarule)
+    addAction("lua.edns-self.tests.powerdns.com.", LuaAction(luarule))
 
     addAction("spoof.edns-self.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
 
@@ -36,33 +36,30 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
@@ -75,11 +72,10 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoDO(self):
         """
@@ -90,45 +86,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
@@ -141,15 +128,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
     def testWithEDNSWithDO(self):
         """
@@ -160,45 +144,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
@@ -211,15 +186,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
     def testWithEDNSNoOptions(self):
         """
@@ -231,45 +203,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
@@ -282,15 +245,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
 
 class TestEDNSSelfGeneratedDisabled(DNSDistTest):
@@ -302,14 +262,14 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
     _config_template = """
     setAddEDNSToSelfGeneratedResponses(false)
 
-    addAction("rcode.edns-self-disabled.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
+    addAction("rcode.edns-self-disabled.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED))
     addAction("tc.edns-self-disabled.tests.powerdns.com.", TCAction())
 
     function luarule(dq)
       return DNSAction.Nxdomain, ""
     end
 
-    addLuaAction("lua.edns-self-disabled.tests.powerdns.com.", luarule)
+    addAction("lua.edns-self-disabled.tests.powerdns.com.", LuaAction(luarule))
 
     addAction("spoof.edns-self-disabled.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
 
@@ -327,33 +287,30 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.tc.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.lua.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.spoof.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
@@ -366,8 +323,7 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
index 6adb863ed012c16a06ec8a2cfcdf4ad1945652f5..87acf47da254373bff73743a83ef0bba60036d46 100644 (file)
@@ -40,19 +40,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -77,19 +72,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSECS(self):
         """
@@ -113,19 +103,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
 
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
     def testWithoutEDNSResponseWithECS(self):
         """
@@ -154,19 +139,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithECS(self):
         """
@@ -195,19 +175,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithCookiesThenECS(self):
         """
@@ -238,19 +213,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         expectedResponse.answer.append(rrset)
         expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithECSThenCookies(self):
         """
@@ -281,19 +251,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         expectedResponse.answer.append(rrset)
         response.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithCookiesThenECSThenCookies(self):
         """
@@ -323,20 +288,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
 
 class TestEdnsClientSubnetOverride(DNSDistTest):
     """
@@ -376,19 +335,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -414,19 +368,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSShorterInitialECS(self):
         """
@@ -454,19 +403,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSLongerInitialECS(self):
         """
@@ -494,19 +438,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSSameSizeInitialECS(self):
         """
@@ -534,19 +473,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
 class TestECSDisabledByRuleOrLua(DNSDistTest):
     """
@@ -586,19 +520,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSDisabledViaRule(self):
         """
@@ -614,19 +543,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryNoEDNS(query, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
     def testWithECSDisabledViaLua(self):
         """
@@ -642,19 +566,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryNoEDNS(query, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
 class TestECSOverrideSetByRuleOrLua(DNSDistTest):
     """
@@ -692,19 +611,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithECSOverrideSetViaRule(self):
         """
@@ -724,19 +638,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithECSOverrideSetViaLua(self):
         """
@@ -756,19 +665,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
 class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
     """
@@ -809,19 +713,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSPrefixLengthOverriddenViaRule(self):
         """
@@ -841,19 +740,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSPrefixLengthOverriddenViaLua(self):
         """
@@ -873,19 +767,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
 class TestECSPrefixSetByRule(DNSDistTest):
     """
@@ -921,19 +810,14 @@ class TestECSPrefixSetByRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSSetByRule(self):
         """
@@ -953,16 +837,11 @@ class TestECSPrefixSetByRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
index 8e25488c07ebc84b6e77d1bacc3967889368cfce..efccf206ff7af8591c91096e6ecb4a691312b1c2 100644 (file)
@@ -163,7 +163,7 @@ class TestHealthCheckCustomFunction(HealthCheckTest):
     function myHealthCheckFunction(qname, qtype, qclass, dh)
       dh:setCD(true)
 
-      return newDNSName('powerdns.com.'), dnsdist.AAAA, qclass
+      return newDNSName('powerdns.com.'), DNSQType.AAAA, qclass
     end
 
     srv = newServer{address="127.0.0.1:%d", checkName='powerdns.org.', checkFunction=myHealthCheckFunction}
index cdee353ee10d7c110326936c6b01e32558012d7f..f8345d2092cb70f191758823f8484767c735d41e 100644 (file)
@@ -9,119 +9,11 @@ from dnsdisttests import DNSDistTest, Queue
 import dns
 import dnsmessage_pb2
 
-
-class TestProtobuf(DNSDistTest):
+class DNSDistProtobufTest(DNSDistTest):
     _protobufServerPort = 4242
     _protobufQueue = Queue()
     _protobufServerID = 'dnsdist-server-1'
     _protobufCounter = 0
-    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
-    _config_template = """
-    luasmn = newSuffixMatchNode()
-    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
-
-    function alterProtobufResponse(dq, protobuf)
-      if luasmn:check(dq.qname) then
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then
-          requestor:truncate(24)
-        else
-          requestor:truncate(56)
-        end
-        protobuf:setRequestor(requestor)
-
-        local tableTags = {}
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-
-        protobuf:setTagArray(tableTags)
-
-        protobuf:setTag('TestLabel3,TestData3')
-
-        protobuf:setTag("Response,456")
-
-      else
-
-        local tableTags = {}                                   -- called by testProtobuf()
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-        protobuf:setTagArray(tableTags)
-
-        protobuf:setTag('TestLabel3,TestData3')
-
-        protobuf:setTag("Response,456")
-
-      end
-    end
-
-    function alterProtobufQuery(dq, protobuf)
-
-      if luasmn:check(dq.qname) then
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then
-          requestor:truncate(24)
-        else
-          requestor:truncate(56)
-        end
-        protobuf:setRequestor(requestor)
-
-        local tableTags = {}
-        tableTags = dq:getTagArray()                           -- get table from DNSQuery
-
-        local tablePB = {}
-          for k, v in pairs( tableTags) do
-          table.insert(tablePB, k .. "," .. v)
-        end
-
-        protobuf:setTagArray(tablePB)                          -- store table in protobuf
-        protobuf:setTag("Query,123")                           -- add another tag entry in protobuf
-
-        protobuf:setResponseCode(dnsdist.NXDOMAIN)             -- set protobuf response code to be NXDOMAIN
-
-        local strReqName = dq.qname:toString()                 -- get request dns name
-
-        protobuf:setProtobufResponseType()                     -- set protobuf to look like a response and not a query, with 0 default time
-
-        blobData = '\127' .. '\000' .. '\000' .. '\001'                -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
-
-        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
-
-        protobuf:setBytes(65)                                  -- set the size of the query to confirm in checkProtobufBase
-
-      else
-
-        local tableTags = {}                                    -- called by testProtobuf()
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-
-        protobuf:setTagArray(tableTags)
-        protobuf:setTag('TestLabel3,TestData3')
-        protobuf:setTag("Query,123")
-
-      end
-    end
-
-    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
-      local tt = {}
-      tt["TestLabel1"] = "TestData1"
-      tt["TestLabel2"] = "TestData2"
-
-      dq:setTagArray(tt)
-
-      dq:setTag("TestLabel3","TestData3")
-      return DNSAction.None, ""                                -- continue to the next rule
-    end
-
-    newServer{address="127.0.0.1:%s", useClientSubnet=true}
-    rl = newRemoteLogger('127.0.0.1:%s')
-
-    addAction(AllRule(), LuaAction(alterLuaFirst))                                                     -- Add tags to DNSQuery first
-
-    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))                             -- Send protobuf message before lookup
-
-    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))    -- Send protobuf message after lookup
-
-    """
 
     @classmethod
     def ProtobufListener(cls, port):
@@ -188,7 +80,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(msg.id, query.id)
         self.assertTrue(msg.HasField('inBytes'))
         self.assertTrue(msg.HasField('serverIdentity'))
-        self.assertEquals(msg.serverIdentity, self._protobufServerID)
+        self.assertEquals(msg.serverIdentity, self._protobufServerID.encode('utf-8'))
 
         if normalQueryResponse:
           # compare inBytes with length of query/response
@@ -244,6 +136,115 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(record.ttl, rttl)
         self.assertTrue(record.HasField('rdata'))
 
+class TestProtobuf(DNSDistProtobufTest):
+    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
+    _config_template = """
+    luasmn = newSuffixMatchNode()
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+    function alterProtobufResponse(dq, protobuf)
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then
+          requestor:truncate(24)
+        else
+          requestor:truncate(56)
+        end
+        protobuf:setRequestor(requestor)
+
+        local tableTags = {}
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+
+        protobuf:setTagArray(tableTags)
+
+        protobuf:setTag('TestLabel3,TestData3')
+
+        protobuf:setTag("Response,456")
+
+      else
+
+        local tableTags = {}                                   -- called by testProtobuf()
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+        protobuf:setTagArray(tableTags)
+
+        protobuf:setTag('TestLabel3,TestData3')
+
+        protobuf:setTag("Response,456")
+
+      end
+    end
+
+    function alterProtobufQuery(dq, protobuf)
+
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then
+          requestor:truncate(24)
+        else
+          requestor:truncate(56)
+        end
+        protobuf:setRequestor(requestor)
+
+        local tableTags = {}
+        tableTags = dq:getTagArray()                           -- get table from DNSQuery
+
+        local tablePB = {}
+          for k, v in pairs( tableTags) do
+          table.insert(tablePB, k .. "," .. v)
+        end
+
+        protobuf:setTagArray(tablePB)                          -- store table in protobuf
+        protobuf:setTag("Query,123")                           -- add another tag entry in protobuf
+
+        protobuf:setResponseCode(DNSRCode.NXDOMAIN)            -- set protobuf response code to be NXDOMAIN
+
+        local strReqName = dq.qname:toString()                 -- get request dns name
+
+        protobuf:setProtobufResponseType()                     -- set protobuf to look like a response and not a query, with 0 default time
+
+        blobData = '\127' .. '\000' .. '\000' .. '\001'                -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
+
+        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
+
+        protobuf:setBytes(65)                                  -- set the size of the query to confirm in checkProtobufBase
+
+      else
+
+        local tableTags = {}                                    -- called by testProtobuf()
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+
+        protobuf:setTagArray(tableTags)
+        protobuf:setTag('TestLabel3,TestData3')
+        protobuf:setTag("Query,123")
+
+      end
+    end
+
+    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
+      local tt = {}
+      tt["TestLabel1"] = "TestData1"
+      tt["TestLabel2"] = "TestData2"
+
+      dq:setTagArray(tt)
+
+      dq:setTag("TestLabel3","TestData3")
+      return DNSAction.None, ""                                -- continue to the next rule
+    end
+
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    rl = newRemoteLogger('127.0.0.1:%s')
+
+    addAction(AllRule(), LuaAction(alterLuaFirst))                                                     -- Add tags to DNSQuery first
+
+    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))                             -- Send protobuf message before lookup
+
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))    -- Send protobuf message after lookup
+
+    """
+
     def testProtobuf(self):
         """
         Protobuf: Send data to a protobuf server
@@ -291,7 +292,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
+        self.assertEquals(rr.rdata.decode('utf-8'), target)
         rr = msg.response.rrs[1]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
@@ -319,7 +320,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
+        self.assertEquals(rr.rdata.decode('utf-8'), target)
         rr = msg.response.rrs[1]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
@@ -390,3 +391,92 @@ class TestProtobuf(DNSDistTest):
         for rr in msg.response.rrs:
             self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
             self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+class TestProtobufIPCipher(DNSDistProtobufTest):
+    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    key = makeIPCipherKey("some 16-byte key")
+    rl = newRemoteLogger('127.0.0.1:%s')
+    addAction(AllRule(), RemoteLogAction(rl, nil, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message before lookup
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, nil, true, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message after lookup
+
+    """
+
+    def testProtobuf(self):
+        """
+        Protobuf: Send data to a protobuf server
+        """
+        name = 'query.protobuf-ipcipher.tests.powerdns.com.'
+
+        target = b'target.protobuf-ipcipher.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    target)
+        response.answer.append(rrset)
+
+        rrset = dns.rrset.from_text(target,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+
+        # 108.41.239.98 is 127.0.0.1 pseudonymized with ipcipher and the current key
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '108.41.239.98')
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '108.41.239.98')
+
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        # 108.41.239.98 is 127.0.0.1 pseudonymized with ipcipher and the current key
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '108.41.239.98')
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '108.41.239.98')
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
index cfd3d614bc067e2ffd2f883d62c5420a60f8f089..1abe575fd0991168e83c3ccbcc18d26311b05d9f 100644 (file)
@@ -7,7 +7,7 @@ from dnsdisttests import DNSDistTest
 class TestRecordsCountOnlyOneAR(DNSDistTest):
 
     _config_template = """
-    addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(dnsdist.REFUSED))
+    addAction(NotRule(RecordsCountRule(DNSSection.Additional, 1, 1)), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -23,11 +23,10 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowOneAR(self):
         """
@@ -45,19 +44,14 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseTwoAR(self):
         """
@@ -76,17 +70,16 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
 
     _config_template = """
     addAction(RecordsCountRule(DNSSection.Answer, 2, 3), AllowAction())
-    addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -102,11 +95,10 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowTwoAN(self):
         """
@@ -126,19 +118,14 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         response = dns.message.make_response(query)
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseFourAN(self):
         """
@@ -160,17 +147,16 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountNothingInNS(DNSDistTest):
 
     _config_template = """
     addAction(RecordsCountRule(DNSSection.Authority, 0, 0), AllowAction())
-    addAction(AllRule(), RCodeAction(dnsdist.REFUSED))
+    addAction(AllRule(), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -193,11 +179,10 @@ class TestRecordsCountNothingInNS(DNSDistTest):
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.authority.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 
     def testRecordsCountAllowEmptyNS(self):
@@ -216,24 +201,19 @@ class TestRecordsCountNothingInNS(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestRecordsCountNoOPTInAR(DNSDistTest):
 
     _config_template = """
-    addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, dnsdist.OPT, 0, 0)), RCodeAction(dnsdist.REFUSED))
+    addAction(NotRule(RecordsTypeCountRule(DNSSection.Additional, DNSQType.OPT, 0, 0)), RCodeAction(DNSRCode.REFUSED))
     newServer{address="127.0.0.1:%s"}
     """
 
@@ -249,11 +229,10 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowNoOPTInAR(self):
         """
@@ -271,19 +250,14 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountAllowTwoARButNoOPT(self):
         """
@@ -312,16 +286,11 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
index 494143c6d657d8420805b574bdc31b2e50c30ef3..fb9276493fd647e014ae787771d059f48a05d127 100644 (file)
@@ -8,7 +8,7 @@ class TestResponseRuleNXDelayed(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addResponseAction(RCodeRule(dnsdist.NXDOMAIN), DelayResponseAction(1000))
+    addResponseAction(RCodeRule(DNSRCode.NXDOMAIN), DelayResponseAction(1000))
     """
 
     def testNXDelayed(self):
@@ -57,7 +57,7 @@ class TestResponseRuleERCode(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addResponseAction(ERCodeRule(dnsdist.BADVERS), DelayResponseAction(1000))
+    addResponseAction(ERCodeRule(DNSRCode.BADVERS), DelayResponseAction(1000))
     """
 
     def testBADVERSDelayed(self):
@@ -122,15 +122,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testNotDropped(self):
         """
@@ -143,15 +140,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestResponseRuleQNameAllowed(DNSDistTest):
 
@@ -172,15 +166,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testNotAllowed(self):
         """
@@ -193,15 +184,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
 class TestResponseRuleEditTTL(DNSDistTest):
 
@@ -236,19 +224,14 @@ class TestResponseRuleEditTTL(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+            self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
 
 class TestResponseLuaActionReturnSyntax(DNSDistTest):
 
@@ -261,7 +244,7 @@ class TestResponseLuaActionReturnSyntax(DNSDistTest):
       return DNSResponseAction.Drop
     end
     addResponseAction("drop.responses.tests.powerdns.com.", LuaResponseAction(customDrop))
-    addResponseAction(RCodeRule(dnsdist.NXDOMAIN), LuaResponseAction(customDelay))
+    addResponseAction(RCodeRule(DNSRCode.NXDOMAIN), LuaResponseAction(customDelay))
     """
 
     def testResponseActionDelayed(self):
@@ -297,12 +280,9 @@ class TestResponseLuaActionReturnSyntax(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
index dfb8c1da6b7d639eac1808c227aade1589bbb1eb..386fba655f91ce2d8235e620cae1ac5283233e40 100644 (file)
@@ -29,15 +29,12 @@ class TestRoutingPoolRouting(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testDefaultPool(self):
         """
@@ -50,11 +47,10 @@ class TestRoutingPoolRouting(DNSDistTest):
         name = 'notpool.routing.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestRoutingQPSPoolRouting(DNSDistTest):
     _config_template = """
@@ -224,6 +220,39 @@ class TestRoutingRoundRobinLBOneDown(DNSDistTest):
 
         self.assertEquals(total, numberOfQueries * 2)
 
+class TestRoutingRoundRobinLBAllDown(DNSDistTest):
+
+    _testServer2Port = 5351
+    _config_params = ['_testServerPort', '_testServer2Port']
+    _config_template = """
+    setServerPolicy(roundrobin)
+    setRoundRobinFailOnNoServer(true)
+    s1 = newServer{address="127.0.0.1:%s"}
+    s1:setDown()
+    s2 = newServer{address="127.0.0.1:%s"}
+    s2:setDown()
+    """
+
+    def testRRWithAllDown(self):
+        """
+        Routing: Round Robin with all servers down
+        """
+        numberOfQueries = 10
+        name = 'alldown.rr.routing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
+
 class TestRoutingOrder(DNSDistTest):
 
     _testServer2Port = 5351
@@ -274,14 +303,12 @@ class TestRoutingOrder(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(response, receivedResponse)
 
         total = 0
         if 'UDP Responder' in self._responsesCounter:
@@ -307,12 +334,10 @@ class TestRoutingNoServer(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.SERVFAIL)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRoutingWRandom(DNSDistTest):
 
index 601167e824572eb550cc030b1657bd8cf3be86a0..4cedcd32bf59294dfdc00fea6839cc146d168946 100644 (file)
@@ -30,13 +30,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionAAAA(self):
         """
@@ -57,13 +55,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionCNAME(self):
         """
@@ -84,13 +80,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     'cnameaction.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiA(self):
         """
@@ -111,13 +105,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '192.0.2.2', '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiAAAA(self):
         """
@@ -138,13 +130,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1', '2001:DB8::2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiANY(self):
         """
@@ -173,13 +163,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1', '2001:DB8::2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestSpoofingLuaSpoof(DNSDistTest):
 
@@ -222,13 +210,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     '192.0.2.1', '192.0.2.2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAAAA(self):
         """
@@ -249,13 +235,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     '2001:DB8::1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAWithCNAME(self):
         """
@@ -276,13 +260,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     'spoofedcname.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAAAAWithCNAME(self):
         """
@@ -303,13 +285,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     'spoofedcname.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestSpoofingLuaWithStatistics(DNSDistTest):
 
@@ -367,10 +347,8 @@ class TestSpoofingLuaWithStatistics(DNSDistTest):
         self.assertTrue(receivedResponse)
         self.assertEquals(expectedResponse2, receivedResponse)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponseAfterwards, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponseAfterwards, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponseAfterwards, receivedResponse)
index a1672e494128142454619d6676942316fdaa91a6..f975c9f403a03d30c6d1481ba207c8d857b86567 100644 (file)
@@ -2,7 +2,12 @@
 import struct
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+
+try:
+    range = xrange
+except NameError:
+    pass
 
 class TestTCPKeepAlive(DNSDistTest):
     """
@@ -13,7 +18,7 @@ class TestTCPKeepAlive(DNSDistTest):
 
     _tcpIdleTimeout = 20
     _maxTCPQueriesPerConn = 99
-    _maxTCPConnsPerClient = 3
+    _maxTCPConnsPerClient = 100
     _maxTCPConnDuration = 99
     _config_template = """
     newServer{address="127.0.0.1:%s"}
@@ -21,9 +26,10 @@ class TestTCPKeepAlive(DNSDistTest):
     setMaxTCPQueriesPerConnection(%s)
     setMaxTCPConnectionsPerClient(%s)
     setMaxTCPConnectionDuration(%s)
-    pc = newPacketCache(100, 86400, 1)
+    pc = newPacketCache(100, {maxTTL=86400, minTTL=1})
     getPool(""):setCache(pc)
-    addAction("refused.tcpka.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
+    addAction("largernumberofconnections.tcpka.tests.powerdns.com.", SkipCacheAction())
+    addAction("refused.tcpka.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED))
     addAction("dropped.tcpka.tests.powerdns.com.", DropAction())
     addResponseAction("dropped-response.tcpka.tests.powerdns.com.", DropResponseAction())
     -- create the pool named "nosuchpool"
@@ -199,6 +205,49 @@ class TestTCPKeepAlive(DNSDistTest):
         conn.close()
         self.assertEqual(count, 0)
 
+    def testTCPKaLargeNumberOfConnections(self):
+        """
+        TCP KeepAlive: Large number of connections
+        """
+        name = 'largernumberofconnections.tcpka.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        #expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        # number of connections
+        numConns = 50
+        # number of queries per connections
+        numQueriesPerConn = 4
+
+        conns = []
+        start = time.time()
+        for idx in range(numConns):
+            conns.append(self.openTCPConnection())
+
+        count = 0
+        for idx in range(numConns * numQueriesPerConn):
+            try:
+                conn = conns[idx % numConns]
+                self.sendTCPQueryOverConnection(conn, query, response=expectedResponse)
+                response = self.recvTCPResponseOverConnection(conn)
+                if response is None:
+                    break
+                self.assertEquals(expectedResponse, response)
+                count = count + 1
+            except:
+                pass
+
+        for con in conns:
+          conn.close()
+
+        self.assertEqual(count, numConns * numQueriesPerConn)
+
 class TestTCPKeepAliveNoDownstreamDrop(DNSDistTest):
     """
     This test makes sure that dnsdist drops the TCP connection
index e67a152d375db844a39c2a02a765ba6745169a7b..ec20c929f5966591f857c55e9e0d9a2e52730f37 100644 (file)
@@ -2,7 +2,12 @@
 import struct
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+
+try:
+  range = xrange
+except NameError:
+  pass
 
 class TestTCPLimits(DNSDistTest):
 
@@ -101,19 +106,20 @@ class TestTCPLimits(DNSDistTest):
         conn.send(struct.pack("!H", 65535))
 
         count = 0
-        while count < (self._maxTCPConnDuration * 2):
+        while count < (self._maxTCPConnDuration * 20):
             try:
                 # sleeping for only one second keeps us below the
                 # idle timeout (setTCPRecvTimeout())
-                time.sleep(1)
-                conn.send('A')
+                time.sleep(0.1)
+                conn.send(b'A')
                 count = count + 1
-            except:
+            except Exception as e:
+                print("Exception: %s!" % (e))
                 break
 
         end = time.time()
 
-        self.assertAlmostEquals(count, self._maxTCPConnDuration, delta=2)
+        self.assertAlmostEquals(count / 10, self._maxTCPConnDuration, delta=2)
         self.assertAlmostEquals(end - start, self._maxTCPConnDuration, delta=2)
 
         conn.close()
index 68aaa83e6d18562556964e00ff2047f7f63e407d..b31c6c2f9bfcf4933bdc99103b5eb2febbf2a035 100644 (file)
@@ -38,3 +38,55 @@ class TestTLS(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
+
+    def testTLKA(self):
+        """
+        TLS: Several queries over the same connection
+        """
+        name = 'ka.tls.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+        for idx in range(5):
+            self.sendTCPQueryOverConnection(conn, query, response=response)
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+
+    def testTLSPipelining(self):
+        """
+        TLS: Several queries over the same connection without waiting for the responses
+        """
+        name = 'pipelining.tls.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+        for idx in range(100):
+            self.sendTCPQueryOverConnection(conn, query, response=response)
+
+        for idx in range(100):
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
index 6da32b9a0f14f3ce18d7751a5ca6463b2fc84fa8..736c87bf950006f0200a5ebcb1888a11cd60bee9 100644 (file)
@@ -55,19 +55,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testQuestionMatchTagAndValue(self):
         """
@@ -85,13 +80,11 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.50')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testQuestionMatchTagOnly(self):
         """
@@ -109,13 +102,11 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.100')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testResponseNoMatch(self):
         """
@@ -131,19 +122,14 @@ class TestBasics(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testResponseMatchTagAndValue(self):
         """
@@ -165,21 +151,14 @@ class TestBasics(DNSDistTest):
         # we will set TC if the tag matches
         expectedResponse.flags |= dns.flags.TC
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        print(expectedResponse)
-        print(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testResponseMatchResponseTagMatches(self):
         """
@@ -201,16 +180,11 @@ class TestBasics(DNSDistTest):
         # we will set QR=0 if the tag matches
         expectedResponse.flags &= ~dns.flags.QR
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
index 5adb3819be470f689969204349b2200f63b5f3ca..156dcda7be2965e47e10679538af5907c624fe83 100644 (file)
@@ -16,8 +16,8 @@ class TestTeeAction(DNSDistTest):
     setKey("%s")
     controlSocket("127.0.0.1:%s")
     newServer{address="127.0.0.1:%d"}
-    addAction(QTypeRule(dnsdist.A), TeeAction("127.0.0.1:%d", true))
-    addAction(QTypeRule(dnsdist.AAAA), TeeAction("127.0.0.1:%d", false))
+    addAction(QTypeRule(DNSQType.A), TeeAction("127.0.0.1:%d", true))
+    addAction(QTypeRule(DNSQType.AAAA), TeeAction("127.0.0.1:%d", false))
     """
     _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort', '_teeServerPort']
     @classmethod
index a9c74b58556bdeaf7f3e38eb9048f33c031e3ee7..b87b519f7ad7b9fe11f42ee82ec84383ab497cb8 100644 (file)
@@ -20,7 +20,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("added.trailing.tests.powerdns.com.", replaceTrailingData)
+    addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
 
     function fillBuffer(dq)
         local available = dq.size - dq.len
@@ -31,7 +31,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("max.trailing.tests.powerdns.com.", fillBuffer)
+    addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))
 
     function exceedBuffer(dq)
         local available = dq.size - dq.len
@@ -42,7 +42,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
+    addAction("limited.trailing.tests.powerdns.com.", LuaAction(exceedBuffer))
     """
     @classmethod
     def startResponders(cls):
@@ -80,8 +80,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -108,8 +106,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -136,8 +132,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (_, receivedResponse) = sender(query, response)
             self.assertTrue(receivedResponse)
             self.assertEquals(receivedResponse, expectedResponse)
@@ -161,8 +155,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -183,13 +175,13 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)
+    addAction("removed.trailing.tests.powerdns.com.", LuaAction(removeTrailingData))
 
     function reportTrailingData(dq)
         local tail = dq:getTrailingData()
         return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
     end
-    addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("echoed.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
 
     function replaceTrailingData(dq)
         local success = dq:setTrailingData("ABC")
@@ -198,8 +190,8 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("replaced.trailing.tests.powerdns.com.", replaceTrailingData)
-    addLuaAction("replaced.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
 
     function reportTrailingHex(dq)
         local tail = dq:getTrailingData()
@@ -208,7 +200,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end)
         return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
     end
-    addLuaAction("echoed-hex.trailing.tests.powerdns.com.", reportTrailingHex)
+    addAction("echoed-hex.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
 
     function replaceTrailingData_unsafe(dq)
         local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
@@ -217,8 +209,8 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", replaceTrailingData_unsafe)
-    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", reportTrailingHex)
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData_unsafe))
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
     """
 
     def testTrailingDropped(self):
@@ -243,8 +235,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             sender = getattr(self, method)
 
             # Verify that queries with no trailing data make it through.
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -253,8 +243,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertEquals(response, receivedResponse)
 
             # Verify that queries with trailing data don't make it through.
-            # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertEquals(receivedResponse, None)
 
@@ -278,8 +266,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -309,8 +295,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -338,8 +322,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -367,8 +349,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -396,8 +376,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
index 7da9f05040c362b54537029f5c770562268bce15..26038091e8e18cb70e872ae4458a585266a75412 100644 (file)
@@ -184,6 +184,8 @@ class testECSByNameLarger(ECSTest):
     _config_template = """edns-subnet-whitelist=ecs-echo.example.
 ecs-ipv4-bits=32
 forward-zones=ecs-echo.example=%s.21
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -244,6 +246,8 @@ class testIncomingECSByName(ECSTest):
 use-incoming-edns-subnet=yes
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=2001:db8::42
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -283,6 +287,8 @@ use-incoming-edns-subnet=yes
 ecs-ipv4-bits=32
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=192.168.0.1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -313,6 +319,8 @@ use-incoming-edns-subnet=yes
 ecs-ipv4-bits=16
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=192.168.0.1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -339,6 +347,8 @@ class testIncomingECSByNameV6(ECSTest):
     _config_template = """edns-subnet-whitelist=ecs-echo.example.
 use-incoming-edns-subnet=yes
 ecs-ipv6-bits=128
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
 forward-zones=ecs-echo.example=%s.21
 query-local-address6=::1
     """ % (os.environ['PREFIX'])
@@ -423,6 +433,8 @@ class testIncomingECSByIP(ECSTest):
 use-incoming-edns-subnet=yes
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=::1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'], os.environ['PREFIX'])
 
     def testSendECS(self):
index 7ef99f31d482be2204a53ce3ac2c1110eb0b1044..992a8bcf4a83deecef0d2c3c79ad202b430ef9da 100644 (file)
@@ -26,15 +26,14 @@ __EOF__
 
        $RUNWRAPPER $PDNS2 --daemon=no --local-port=$port --config-dir=. \
                --config-name=lmdb2 --socket-dir=./ --no-shuffle \
-               --slave --retrieval-threads=4 \
-               --slave-cycle-interval=300 --dname-processing &
+               --slave --dname-processing --api --api-key=secret &
 
        echo 'waiting for zones to be slaved'
        loopcount=0
        while [ $loopcount -lt 30 ]
        do
                sleep 5
-               present=$($PDNSUTIL --config-dir=. --config-name=lmdb2 list-all-zones slave | wc -l)
+               present=$(curl -s -S -H 'X-API-Key: secret' http://127.0.0.1:8081/api/v1/servers/localhost/zones  | jq -r '.[] | .serial' | grep -c -v '^0$')
                if [ $zones -eq $present ]
                then
                        break