]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #7643 from mind04/lmdb-fix
authorbert hubert <bert.hubert@netherlabs.nl>
Tue, 9 Apr 2019 15:22:09 +0000 (17:22 +0200)
committerGitHub <noreply@github.com>
Tue, 9 Apr 2019 15:22:09 +0000 (17:22 +0200)
auth: lmdbbackend, fix getAllDomains()

167 files changed:
.circleci/config.yml [new file with mode: 0644]
README.md
build-scripts/circleci.sh [new file with mode: 0755]
builder-support/dockerfiles/Dockerfile.rpmbuild
configure.ac
docs/backends/generic-mysql.rst
docs/backends/generic-postgresql.rst
docs/changelog/4.1.rst
docs/changelog/4.2.rst
docs/common/api/endpoint-statistics.rst
docs/dnssec/modes-of-operation.rst
docs/http-api/swagger/authoritative-api-swagger.yaml
docs/manpages/dnswasher.1.rst
docs/manpages/ixfrdist.yml.5.rst
docs/manpages/pdnsutil.1.rst
docs/secpoll.zone
docs/settings.rst
ext/Makefile.am
ext/ipcrypt/.gitignore [new file with mode: 0644]
ext/ipcrypt/LICENSE [new file with mode: 0644]
ext/ipcrypt/Makefile.am [new file with mode: 0644]
ext/ipcrypt/ipcrypt.c [new file with mode: 0644]
ext/ipcrypt/ipcrypt.h [new file with mode: 0644]
m4/pdns_check_libcrypto.m4
modules/gmysqlbackend/gmysqlbackend.cc
modules/gmysqlbackend/smysql.cc
modules/gmysqlbackend/smysql.hh
modules/lmdbbackend/lmdb-typed.hh
modules/randombackend/randombackend.cc
pdns/Makefile.am
pdns/README-dnsdist.md
pdns/backends/gsql/gsqlbackend.cc
pdns/calidns.cc
pdns/common_startup.cc
pdns/dbdnsseckeeper.cc
pdns/devpollmplexer.cc
pdns/distributor.hh
pdns/dnscrypt.hh
pdns/dnsdist-cache.cc
pdns/dnsdist-cache.hh
pdns/dnsdist-carbon.cc
pdns/dnsdist-console.cc
pdns/dnsdist-dynblocks.hh
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua-vars.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist-web.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/configure.ac
pdns/dnsdistdist/dnsdist-idstate.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/advanced/luaaction.rst
pdns/dnsdistdist/docs/advanced/timedipsetrule.rst
pdns/dnsdistdist/docs/advanced/tuning.rst
pdns/dnsdistdist/docs/reference/comboaddress.rst
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/docs/reference/constants.rst
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/docs/reference/tuning.rst
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsdistdist/docs/statistics.rst
pdns/dnsdistdist/docs/upgrade_guide.rst
pdns/dnsdistdist/ext/ipcrypt/LICENSE [new symlink]
pdns/dnsdistdist/ext/ipcrypt/Makefile.am [new symlink]
pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c [new symlink]
pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h [new symlink]
pdns/dnsdistdist/html/local.js
pdns/dnsdistdist/ipcipher.cc [new symlink]
pdns/dnsdistdist/ipcipher.hh [new symlink]
pdns/dnsdistdist/tcpiohandler.cc
pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc
pdns/dnsdistdist/test-dnsdistrules_cc.cc
pdns/dnsdistdist/test-mplexer.cc [new symlink]
pdns/dnswasher.cc
pdns/dolog.hh
pdns/epollmplexer.cc
pdns/ipcipher.cc [new file with mode: 0644]
pdns/ipcipher.hh [new file with mode: 0644]
pdns/iputils.cc
pdns/iputils.hh
pdns/ixfrdist-web.cc
pdns/ixfrdist-web.hh
pdns/ixfrdist.cc
pdns/ixfrdist.example.yml
pdns/kqueuemplexer.cc
pdns/lua-record.cc
pdns/lua-recursor4.cc
pdns/mplexer.hh
pdns/mtasker.cc
pdns/mtasker.hh
pdns/notify.cc
pdns/packethandler.cc
pdns/pdns_recursor.cc
pdns/pdnsutil.cc
pdns/pollmplexer.cc
pdns/portsmplexer.cc
pdns/rec-carbon.cc
pdns/rec-snmp.cc
pdns/rec_channel.hh
pdns/rec_channel_rec.cc
pdns/recursordist/Makefile.am
pdns/recursordist/README.md
pdns/recursordist/RECURSOR-MIB.txt
pdns/recursordist/configure.ac
pdns/recursordist/docs/changelog/4.1.rst
pdns/recursordist/docs/manpages/rec_control.1.rst
pdns/recursordist/docs/metrics.rst
pdns/recursordist/docs/settings.rst
pdns/recursordist/test-mplexer.cc [new symlink]
pdns/recursordist/test-syncres_cc.cc
pdns/responsestats-auth.cc
pdns/speedtest.cc
pdns/sstuff.hh
pdns/statbag.cc
pdns/statbag.hh
pdns/statnode.hh
pdns/syncres.cc
pdns/syncres.hh
pdns/tcpiohandler.hh
pdns/tcpreceiver.cc
pdns/test-dnsdist_cc.cc
pdns/test-dnsdistpacketcache_cc.cc
pdns/test-ipcrypt_cc.cc [new file with mode: 0644]
pdns/test-mplexer.cc [new file with mode: 0644]
pdns/test-packetcache_cc.cc
pdns/test-recpacketcache_cc.cc
pdns/unix_utility.cc
pdns/utility.hh
pdns/webserver.cc
pdns/webserver.hh
pdns/ws-api.cc
pdns/ws-api.hh
pdns/ws-auth.cc
pdns/ws-recursor.cc
regression-tests.api/test_Servers.py
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/requirements.txt
regression-tests.dnsdist/test_API.py
regression-tests.dnsdist/test_AXFR.py
regression-tests.dnsdist/test_Advanced.py
regression-tests.dnsdist/test_Basics.py
regression-tests.dnsdist/test_Caching.py
regression-tests.dnsdist/test_Carbon.py
regression-tests.dnsdist/test_DynBlocks.py
regression-tests.dnsdist/test_EDNSOptions.py
regression-tests.dnsdist/test_EDNSSelfGenerated.py
regression-tests.dnsdist/test_EdnsClientSubnet.py
regression-tests.dnsdist/test_Protobuf.py
regression-tests.dnsdist/test_RecordsCount.py
regression-tests.dnsdist/test_Responses.py
regression-tests.dnsdist/test_Routing.py
regression-tests.dnsdist/test_Spoofing.py
regression-tests.dnsdist/test_TCPKeepAlive.py
regression-tests.dnsdist/test_TCPLimits.py
regression-tests.dnsdist/test_TLS.py
regression-tests.dnsdist/test_Tags.py
regression-tests.dnsdist/test_Trailing.py
regression-tests.recursor-dnssec/test_ECS.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644 (file)
index 0000000..5905c91
--- /dev/null
@@ -0,0 +1,80 @@
+version: 2.1
+
+commands:
+  auth-regress:
+    description: "run one auth regression context"
+    parameters:
+      skip:
+        type: string
+        default: ""
+      context:
+        type: string
+        default: ""
+    steps:
+      - run: |
+          cd regression-tests
+          [ -e ./vars ] && . ./vars
+          rm -rf tests/*/skip
+          for t in << parameters.skip >>
+          do
+            touch tests/$t/skip
+          done
+          ./start-test-stop 5300 << parameters.context >>
+
+jobs:
+  build:
+    docker:
+      - image: debian:stretch
+      - image: mcr.microsoft.com/mssql/server:2019-CTP2.2-ubuntu
+        environment:
+          - ACCEPT_EULA: Y
+          - SA_PASSWORD: 'SAsa12%%'
+
+    steps:
+      - checkout
+
+      - run:
+          name: install dependencies
+          command: ./build-scripts/circleci.sh debian-stretch-deps
+
+      - run:
+          name: autoconf
+          command: autoreconf -vfi
+
+      - run:
+          name: configure
+          command: ./configure --disable-lua-records --with-modules='bind gmysql godbc random'
+
+      - run:
+          name: build
+          command: make -j3 -k
+
+      - run:
+          name: test gsqlite3 odbc
+          command: ./build-scripts/circleci.sh configure-odbc-sqlite; cd regression-tests ; touch tests/verify-dnssec-zone/allow-missing ; GODBC_SQLITE3_DSN=pdns-sqlite3-1 ./start-test-stop 5300 godbc_sqlite3-nsec3
+
+      - run:
+          name: set up mssql odbc
+          command: ./build-scripts/circleci.sh configure-odbc-mssql ; echo 'create database pdns' | isql -v pdns-mssql-docker-nodb sa SAsa12%%
+
+      - run:
+          name: set up mssql odbc testing
+          command: |
+            cd regression-tests
+            echo 'export GODBC_MSSQL_PASSWORD=SAsa12%% GODBC_MSSQL_USERNAME=sa GODBC_MSSQL_DSN=pdns-mssql-docker' > ./vars
+
+      - auth-regress:
+          context: godbc_mssql-nodnssec
+          skip: 8bit-txt-unescaped
+      - auth-regress:
+          context: godbc_mssql
+          skip: 8bit-txt-unescaped
+      - auth-regress:
+          context: godbc_mssql-nsec3
+          skip: 8bit-txt-unescaped
+      - auth-regress:
+          context: godbc_mssql-nsec3-optout
+          skip: 8bit-txt-unescaped verify-dnssec-zone
+      - auth-regress:
+          context: godbc_mssql-nsec3-narrow
+          skip: 8bit-txt-unescaped
index 3f85d438f6d38e35e5dee6d71407b760dd4b20e6..74a3b840b0ecbbe09385361be9aaa2e268fef93e 100644 (file)
--- a/README.md
+++ b/README.md
@@ -96,11 +96,11 @@ If you run into C++11-related symbol trouble, please try passing `CPPFLAGS=-D_GL
 
 Compiling the Recursor
 ----------------------
-See the README in pdns/recursordist.
+See [README.md](pdns/recursordist/README.md) in `pdns/recursordist/`.
 
 Compiling dnsdist
 -----------------
-See the README in pdns/dnsdistdist.
+See [README-dnsdist.md](pdns/README-dnsdist.md) in `pdns/`.
 
 Building the HTML documentation
 -------------------------------
diff --git a/build-scripts/circleci.sh b/build-scripts/circleci.sh
new file mode 100755 (executable)
index 0000000..63f6e8b
--- /dev/null
@@ -0,0 +1,77 @@
+#!/bin/sh
+set -e
+
+case $1 in
+    debian-stretch-deps)
+        apt-get update && apt-get -qq --no-install-recommends install \
+            autoconf \
+            automake \
+            bc \
+            bind9utils \
+            bison \
+            default-jre-headless \
+            default-libmysqlclient-dev \
+            dnsutils \
+            flex \
+            freetds-bin \
+            g++ \
+            git \
+            ldnsutils \
+            libboost-all-dev \
+            libsqliteodbc \
+            libssl-dev \
+            libtool \
+            make \
+            pkg-config \
+            ragel \
+            sqlite3 \
+            tdsodbc \
+            unbound-host \
+            unixodbc \
+            unixodbc-dev \
+            virtualenv \
+            wget
+
+        wget https://github.com/dblacka/jdnssec-tools/releases/download/0.14/jdnssec-tools-0.14.tar.gz
+        tar xfz jdnssec-tools-0.14.tar.gz --strip-components=1 -C /
+        rm jdnssec-tools-0.14.tar.gz
+
+        ;;
+    configure-odbc-sqlite)
+        cat >> ~/.odbc.ini << __EOF__
+[pdns-sqlite3-1]
+Driver = SQLite3
+Database = ${PWD}/regression-tests/pdns.sqlite3
+
+[pdns-sqlite3-2]
+Driver = SQLite3
+Database = ${PWD}/regression-tests/pdns.sqlite32
+
+__EOF__
+        ;;
+    configure-odbc-mssql)
+        cat >> ~/.odbc.ini << __EOF__
+[pdns-mssql-docker]
+Driver=FreeTDS
+Trace=No
+Server=127.0.0.1
+Port=1433
+Database=pdns
+TDS_Version=7.1
+
+[pdns-mssql-docker-nodb]
+Driver=FreeTDS
+Trace=No
+Server=127.0.0.1
+Port=1433
+TDS_Version=7.1
+
+__EOF__
+
+        cat /usr/share/tdsodbc/odbcinst.ini >> /etc/odbcinst.ini
+        ;;
+    *)
+        echo unknown command "$1"
+        exit 1
+        ;;
+esac
\ No newline at end of file
index b51e0c528501247fc133c83af5fab641e7e89bfa..91a2bd57b721ab4233e7ba89af8d6c33fc4e72eb 100644 (file)
@@ -1,5 +1,5 @@
 FROM dist-base as package-builder
-RUN yum install -y rpm-build rpmdevtools python34 && \
+RUN yum install -y rpm-build rpmdevtools /usr/bin/python3 && \
     yum groupinstall -y "Development Tools" && \
     rpmdev-setuptree
 
index 6325b069041bae0c0c3c4dc268e0d0ff8e06ce5a..1c00fff9756336949b49656bffb9098010cf281d 100644 (file)
@@ -17,8 +17,9 @@ AM_INIT_AUTOMAKE([foreign dist-bzip2 no-dist-gzip tar-ustar -Wno-portability sub
 AM_SILENT_RULES([yes])
 
 AC_CANONICAL_HOST
-: ${CFLAGS="-Wall -g -O2"}
-: ${CXXFLAGS="-Wall -g -O2"}
+# Add some default CFLAGS and CXXFLAGS, can be appended to using the environment variables
+CFLAGS="-Wall -g -O2 $CFLAGS"
+CXXFLAGS="-Wall -g -O2 $CXXFLAGS"
 
 AC_PROG_CC
 AM_PROG_CC_C_O
@@ -106,7 +107,7 @@ PDNS_CHECK_LIBCRYPTO_EDDSA
 PDNS_CHECK_RAGEL([pdns/dnslabeltext.cc], [www.powerdns.com])
 PDNS_CHECK_CLOCK_GETTIME
 
-BOOST_REQUIRE([1.35])
+BOOST_REQUIRE([1.42])
 # Boost accumulators, as used by dnsbulktest and dnstcpbench, need 1.48+
 # to be compatible with C++11
 AM_CONDITIONAL([HAVE_BOOST_GE_148], [test "$boost_major_version" -ge 148])
@@ -296,6 +297,8 @@ AC_SUBST([AM_CPPFLAGS],
 
 AC_SUBST([YAHTTP_CFLAGS], ['-I$(top_srcdir)/ext/yahttp'])
 AC_SUBST([YAHTTP_LIBS], ['$(top_builddir)/ext/yahttp/yahttp/libyahttp.la'])
+AC_SUBST([IPCRYPT_CFLAGS], ['-I$(top_srcdir)/ext/ipcrypt'])
+AC_SUBST([IPCRYPT_LIBS], ['$(top_builddir)/ext/ipcrypt/libipcrypt.la'])
 
 CXXFLAGS="$SANITIZER_FLAGS $CXXFLAGS"
 
@@ -314,6 +317,7 @@ AC_CONFIG_FILES([
   docs/Makefile
   pdns/pdns.init
   ext/Makefile
+  ext/ipcrypt/Makefile
   ext/yahttp/Makefile
   ext/yahttp/yahttp/Makefile
   ext/json11/Makefile
index c04070d58c30588f8af551215c9d2cdcc82def75..02d39fe999daab71f491c8aafaa3dab87584db54 100644 (file)
@@ -123,6 +123,16 @@ Use the InnoDB READ-COMMITTED transaction isolation level. Default: yes.
 The timeout in seconds for each attempt to read from, or write to the
 server. A value of 0 will disable the timeout. Default: 10
 
+.. _setting-gmysql-thread-cleanup:
+
+``gmysql-thread-cleanup``
+^^^^^^^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.1.8
+
+Older versions (such as those shipped on RHEL 7) of the MySQL/MariaDB client libraries leak memory unless applications explicitly report the end of each thread to the library. Enabling ``gmysql-thread-cleanup`` tells PowerDNS to call ``mysql_thread_end()`` whenever a thread ends.
+
+Only enable this if you are certain you need to. For more discussion, see https://github.com/PowerDNS/pdns/issues/6231.
+
 Default Schema
 --------------
 
index cb396fb7e916e895cd591ba3aed8be385cb767bb..824233413d129067b2bfdd0804f3f5f801df305b 100644 (file)
@@ -79,9 +79,9 @@ The password to for :ref:`setting-gpgsql-user`. Default: not set.
 
 Enable DNSSEC processing for this backend. Default: no.
 
-.. _setting-gpsql-extra-connection-parameters:
+.. _setting-gpgsql-extra-connection-parameters:
 
-``gpsql-extra-connection-parameters``
+``gpgsql-extra-connection-parameters``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Extra connection parameters to forward to postgres. If you want to pin a
index ac1236b6364f105fa0fb2c32dea1ed6ada19e040..0b37c2614c98c5384efa2e7d28fca77e756f1ea4 100644 (file)
@@ -1,6 +1,79 @@
 Changelogs for 4.1.x
 ====================
 
+.. changelog::
+  :version: 4.1.8
+  :released: March 22nd 2019
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7604
+    :tickets: 7494
+
+    Correctly interpret an empty AXFR response to an IXFR query.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7610
+    :tickets: 7341
+
+    Fix replying from ANY address for non-standard port.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7609
+    :tickets: 7580
+
+    Fix rectify for ENT records in narrow zones.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7607
+    :tickets: 7472
+
+    Do not compress the root.
+
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7608
+    :tickets: 7459
+
+    Fix dot stripping in ``setcontent()``.
+
+  .. change::
+    :tags: Bug Fixes, MySQL
+    :pullreq: 7605
+    :tickets: 7496
+
+    Fix invalid SOA record in MySQL which prevented the authoritative server from starting.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7603
+    :tickets: 7294
+
+    Prevent leak of file descriptor if running out of ports for incoming AXFR.
+
+  .. change::
+    :tags: Bug Fixes, API
+    :pullreq: 7602
+    :tickets: 7546
+
+    Fix API search failed with "Commands out of sync; you can't run this command now".
+
+  .. change::
+    :tags: Bug Fixes, MySQL
+    :pullreq: 7509
+    :tickets: 7517
+
+    Plug ``mysql_thread_init`` memory leak.
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7567
+
+    EL6: fix ``CXXFLAGS`` to build with compiler optimizations.
+
 .. changelog::
   :version: 4.1.7
   :released: March 18th 2019
index 06a66a6ed06bf1b88a75ae8ba4fbc78e0458ed19..859a0c111cf61bdde8f3232804dd079777f9132a 100644 (file)
@@ -5,6 +5,32 @@ Changelogs for 4.2.x
   :version: 4.2.0
   :released: *unreleased*
 
+  .. change::
+    :tags: Bug Fixes
+    :pullreq: 7576
+    :tickets: 7573
+
+    Insufficient validation in the HTTP remote backend (CVE-2019-3871, PowerDNS Security Advisory :doc:`2019-03 <../security-advisories/powerdns-advisory-2019-03>`)
+
+  .. change::
+    :tags: Bug Fixes, API
+    :pullreq: 7546
+    :tickets: 7545
+
+    Fix API search failed with "Commands out of sync; you can't run this command now".
+
+  .. change::
+    :tags: Bug Fixes, GeoIP
+    :pullreq: 7219
+
+    Fix static lookup when using weighted records on multiple record types.
+
+  .. change::
+    :tags: Improvements, DNSSEC
+    :pullreq: 7516
+
+    Report ``checkKey`` errors upwards.
+
   .. change::
     :tags: Bug Fixes, MySQL
     :pullreq: 7496
index 761980757e5649eb6a913e7db3c9f01bddb10e4a..771dae08b9fb31ad47d79840762e6be406ecba87 100644 (file)
@@ -1,7 +1,7 @@
 Statistics endpoint
 ===================
 
-.. http:get:: /api/v1/servers/:server_id/statistics
+.. http:get:: /api/v1/servers/:server_id/statistics?statistic=:statistic
 
   Query PowerDNS internal statistics.
   Returns a list of :json:object:`StatisticItem` elements.
@@ -10,6 +10,10 @@ Statistics endpoint
 
   :param server_id: The name of the server
 
+  .. versionadded:: 4.2.0
+
+  :query statistic: If set to the name of a specific statistic, only this value is returned. If no statistic with that name exists, the response has a 422 status and an error message
+
   **Example response:**
 
   .. code-block:: json
index c89b6e55713fbca9a5f5916d8306a0dc57a5ddf4..f6679dafbd3fe8ebe5f10dfb8603b34823f7bc91 100644 (file)
@@ -7,7 +7,7 @@ authoritative server. PowerDNS supports this mode fully.
 
 In addition, PowerDNS supports taking care of the signing itself, in
 which case PowerDNS operates differently from most tutorials and
-handbooks. This mode is easier however.
+handbooks. This mode is easier, however.
 
 For relevant tradeoffs, please see :doc:`../security` and
 :doc:`../performance`.
@@ -18,16 +18,16 @@ Online Signing
 --------------
 
 In the simplest situation, there is a single "SQL" database that
-contains, in separate tables, all domain data, keying material and other
+contains, in separate tables, all domain data, keying material, and other
 DNSSEC related settings.
 
 This database is then replicated to all PowerDNS instances, which all
-serve identical records, keys and signatures.
+serve identical records, keys, and signatures.
 
 In this mode of operation, care should be taken that the database
 replication occurs over a secure network, or over an encrypted
-connection. This is because keying material, if intercepted, could be
-used to counterfeit DNSSEC data using the original keys.
+connection. If intercepted, keying material could be used to counterfeit
+DNSSEC data using the original keys.
 
 Such a single replicated database requires no further attention beyond
 monitoring already required during non-DNSSEC operations.
@@ -45,17 +45,17 @@ Zone Signing Keys (ZSKs). During normal operations, this means that only
 1 ZSK is 'active', and the other is inactive.
 
 Should it be desired to 'roll over' to a new key, both keys can
-temporarily be active (and used for signing), and after a while the old
-key can be inactivated. Subsequently it can be removed.
+temporarily be active (and used for signing), and after a while, the old
+key can be deactivated. Subsequently, it can be removed.
 
-As elucidated above, there are several ways in which DNSSEC can deny the
-existence of a record, and this setting too is stored away from zone
-records, and lives with the DNSSEC keying material.
+As described above, there are several ways in which DNSSEC can deny the
+existence of a record, and this setting, which is also stored away from zone
+records, lives with the DNSSEC keying material.
 
 (Hashed) Denial of Existence
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-PowerDNS supports unhashed secure denial of existence using NSEC
+PowerDNS supports unhashed secure denial-of-existence using NSEC
 records. These are generated with the help of the (database) backend,
 which needs to be able to supply the 'previous' and 'next' records in
 canonical ordering.
@@ -63,7 +63,7 @@ canonical ordering.
 The Generic SQL Backends have fields that allow them to supply these
 relative record names.
 
-In addition, hashed secure denial of existence is supported using NSEC3
+In addition, hashed secure denial-of-existence is supported using NSEC3
 records, in two modes, one with help from the database, the other with
 the help of some additional calculations.
 
@@ -72,7 +72,7 @@ where the backend should be able to supply the previous and next domain
 names in hashed order.
 
 NSEC3 in 'narrow' mode uses additional hashing calculations to provide
-hashed secure denial of existence 'on the fly', without further
+hashed secure denial-of-existence 'on the fly', without further
 involving the database.
 
 .. _dnssec-signatures:
@@ -84,8 +84,8 @@ In PowerDNS live signing mode, signatures, as served through RRSIG
 records, are calculated on the fly, and heavily cached. All CPU cores
 are used for the calculation.
 
-RRSIGs have a validity period, in PowerDNS this period is 3 weeks.
-This period starts at most a week in the past, and continues at least a week into the future.
+RRSIGs have a validity period. In PowerDNS, the RRSIG validity period is 3 weeks.
+This period starts at most a week in the past and continues at least a week into the future.
 This interval jumps with one-week increments every Thursday.
 
 The time period used is always calculated based on the moment of rollover.
@@ -105,7 +105,7 @@ Graphically, it looks like this::
                               |----- RRSIG(1) served -----|----- RRSIG(2) served -----|
 
   |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
-  thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu fri sat sun mon tue wed thu
+  Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu Fri Sat Sun Mon Tue Wed Thu
                                                           ^
                                                           |
                                                           RRSIG roll-over(1 to 2)
@@ -115,8 +115,8 @@ At all times, only one RRSIG per signed RRset per ZSK is served when responding
 .. note::
   Why Thursday? POSIX-based operating systems count the time
   since GMT midnight January 1st of 1970, which was a Thursday. PowerDNS
-  inception/expiration times are generated based on an integral number of
-  weeks having passed since the start of the 'epoch'.
+  inception/expiration times are generated based on the integral number of
+  weeks since the start of the 'epoch'.
 
 PowerDNS also serves the DNSKEY records in live-signing mode. Their TTL
 is derived from the SOA records *minimum* field. When using NSEC3, the
@@ -127,7 +127,7 @@ Pre-signed records
 
 In this mode, PowerDNS serves zones that already contain DNSSEC records.
 Such zones can either be slaved from a remote master, or can be signed
-using tools like OpenDNSSEC, ldns-signzone or dnssec-signzone.
+using tools like OpenDNSSEC, ldns-signzone, and dnssec-signzone.
 
 Even in this mode, PowerDNS will synthesize NSEC(3) records itself
 because of its architecture. RRSIGs of these NSEC(3) will still need to
@@ -155,7 +155,7 @@ required for the receiving party to rectify the zone without knowing the
 keys, such as signed NSEC3 records for empty non-terminals. The zone is
 not required to be rectified on the master.
 
-Signatures and Hashing is similar as described in :ref:`dnssec-online-signing`.
+The signing and hashing algorithms are described in :ref:`dnssec-online-signing`.
 
 .. _dnssec-modes-bind-mode:
 
@@ -171,7 +171,7 @@ To use this mode, add
 restart PowerDNS.
 
 .. note::
-  This sqlite database is different from the database used for the regular :doc:`SQLite 3 backend <../backends/generic-sqlite3>`.
+  This SQLite database is different from the database used for the regular :doc:`SQLite 3 backend <../backends/generic-sqlite3>`.
 
 After this, you can use ``pdnsutil secure-zone`` and all other pdnsutil
 commands on your BIND zones without trouble.
@@ -182,7 +182,7 @@ Hybrid BIND-mode operation
 --------------------------
 
 PowerDNS can also operate based on 'BIND'-style zone & configuration
-files. This 'bindbackend' has full knowledge of DNSSEC, but has no
+files. This 'bindbackend' has full knowledge of DNSSEC but has no
 native way of storing keying material.
 
 However, since PowerDNS supports operation with multiple simultaneous
index faddbb2497af9646a0d59159c3bc5d0632ad0d07..63fcb3c4d65d01a232427a9e8bc7331d2d04bc0c 100644 (file)
@@ -399,6 +399,14 @@ paths:
           required: true
           description: The id of the server to retrieve
           type: string
+        - name: statistic
+          in: query
+          required: false
+          type: string
+          description: |
+            When set to the name of a specific statistic, only this value is returned.
+            If no statistic with that name exists, the response has a 422 status and an error message.
+
       responses:
         '200':
           description: List of Statistic Items
@@ -408,6 +416,8 @@ paths:
             - $ref: '#/definitions/StatisticItem'
             - $ref: '#/definitions/MapStatisticItem'
             - $ref: '#/definitions/RingStatisticItem'
+        '422':
+          description: 'Returned when a non-existing statistic name has been requested. Contains an error message'
 
   '/servers/{server_id}/search-data':
     get:
index 46f5de42d818ab5cca86a28ad0fb15a241cb2cd4..1b20806226cf48d2b45030e5738f1ed1e379c554 100644 (file)
@@ -25,7 +25,11 @@ about.
 Options
 -------
 
-None
+--decrypt,-d             Undo IPCipher encryption of IP addresses
+--help, -h               Show summary of options.
+--key,-k                 Base64 encoded 128-bit key for IPCipher
+--passphrase,-p          Passphrase that will be used to derive an IPCipher key
+--version,-v             Output version
 
 See also
 --------
index 22b48d80f5707f5904129cd3ca5df00552d2f30d..121a696ef2cd73b2a6af3db2b5a22849a983922a 100644 (file)
@@ -113,6 +113,14 @@ Options
   Entries without a netmask will be interpreted as a single address.
   By default, this list is set to ``127.0.0.0/8`` and ``::1/128``.
 
+:webserver-loglevel:
+  How much the webserver should log: 'none', 'normal' or 'detailed'.
+  When logging, each log-line contains the UUID of the request, this allows finding errors caused by certain requests.
+  With 'none', nothing is logged except for errors.
+  With 'normal' (the default), one line per request is logged in the style of the common log format::
+    [NOTICE] [webserver] 46326eef-b3ba-4455-8e76-15ec73879aa3 127.0.0.1:57566 "GET /metrics HTTP/1.1" 200 1846
+  with 'detailed', the full requests and responses (including headers) are logged along with the regular log-line from 'normal'.
+
 See also
 --------
 
index 2a5511665b98deeeb3c0b5c29d246c27cd481d8a..3b02ccfbc75d3e3011c41bb237477b4bc35cbe84 100644 (file)
@@ -232,6 +232,14 @@ bench-db [*FILE*]
     *FILE* can be a file with a list, one per line, of domain names to use for this.
     If *FILE* is not specified, powerdns.com is used.
 
+OTHER TOOLS
+-----------
+ipencrypt *IP-ADDRESS* passsword
+    Encrypt an IP address according to the 'ipcipher' standard
+
+ipdecrypt *IP-ADDRESS* passsword
+    Encrypt an IP address according to the 'ipcipher' standard
+
 See also
 --------
 
index 400790aa8e91d50a00b4d80ef38a8d5e0dc739fa..83c2a0ef8f86e5f30035d2fa6188fe23bda1f195 100644 (file)
@@ -1,4 +1,4 @@
-@       86400   IN  SOA pdns-public-ns1.powerdns.com. pieter\.lexis.powerdns.com. 2019031801 10800 3600 604800 10800
+@       86400   IN  SOA pdns-public-ns1.powerdns.com. pieter\.lexis.powerdns.com. 2019040201 10800 3600 604800 10800
 @       3600    IN  NS  pdns-public-ns1.powerdns.com.
 @       3600    IN  NS  pdns-public-ns2.powerdns.com.
 ; Auth
@@ -44,8 +44,10 @@ auth-4.1.4.security-status                              60 IN TXT "3 Upgrade now
 auth-4.1.5.security-status                              60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.1.6.security-status                              60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
 auth-4.1.7.security-status                              60 IN TXT "1 OK"
-auth-4.2.0-alpha1.security-status                       60 IN TXT "1 OK"
-auth-4.2.0-beta1.security-status                        60 IN TXT "1 OK"
+auth-4.1.8.security-status                              60 IN TXT "1 OK"
+auth-4.2.0-alpha1.security-status                       60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
+auth-4.2.0-beta1.security-status                        60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/authoritative/security-advisories/powerdns-advisory-2019-03.html"
+auth-4.2.0-rc1.security-status                          60 IN TXT "1 OK"
 
 ; Auth Debian
 auth-3.4.1-2.debian.security-status                     60 IN TXT "3 Upgrade now, see https://doc.powerdns.com/3/security/powerdns-advisory-2015-01/ and https://doc.powerdns.com/3/security/powerdns-advisory-2015-02/ and https://doc.powerdns.com/3/security/powerdns-advisory-2016-02/ and https://doc.powerdns.com/3/security/powerdns-advisory-2016-03/ and https://doc.powerdns.com/3/security/powerdns-advisory-2016-04/ and https://doc.powerdns.com/3/security/powerdns-advisory-2016-05/"
@@ -178,6 +180,7 @@ recursor-4.1.8.security-status                          60 IN TXT "3 Upgrade now
 recursor-4.1.9.security-status                          60 IN TXT "1 OK"
 recursor-4.1.10.security-status                         60 IN TXT "1 OK"
 recursor-4.1.11.security-status                         60 IN TXT "1 OK"
+recursor-4.1.12.security-status                         60 IN TXT "1 OK"
 recursor-4.2.0-alpha1.security-status                   60 IN TXT "1 OK"
 
 ; Recursor Debian
index a622d765825064af3ffbe772bc6ee2bb8c7a3dc7..eab1244d270046cbe771b179dcf3c2b2f772292c 100644 (file)
@@ -1615,6 +1615,47 @@ IP Address for webserver/API to listen on.
 
 Webserver/API access is only allowed from these subnets.
 
+.. _setting-webserver-loglevel:
+
+``webserver-loglevel``
+----------------------
+.. versionadded:: 4.2.0
+
+-  String, one of "none", "normal", "detailed"
+
+The amount of logging the webserver must do. "none" means no useful webserver information will be logged.
+When set to "normal", the webserver will log a line per request that should be familiar::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+When set to "detailed", all information about the request and response are logged::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Request Details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-encoding: gzip, deflate
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-language: en-US,en;q=0.5
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   connection: keep-alive
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   dnt: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   host: 127.0.0.1:8081
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   upgrade-insecure-requests: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   user-agent: Mozilla/5.0 (X11; Linux x86_64; rv:64.0) Gecko/20100101 Firefox/64.0
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  No body
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Response details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Connection: close
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Length: 49
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Type: text/html; charset=utf-8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Server: PowerDNS/0.0.15896.0.gaba8bab3ab
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Full body: 
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   <!html><title>Not Found</title><h1>Not Found</h1>
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+The value between the hooks is a UUID that is generated for each request. This can be used to find all lines related to a single request.
+
+.. note::
+  The webserver logs these line on the NOTICE level. The :ref:`settings-loglevel` seting must be 5 or higher for these lines to end up in the log.
+
 .. _setting-webserver-password:
 
 ``webserver-password``
index 49dc5a219216ad3860bb46a49888ee0f280db098..7c0a42d4199fc6feaaaf3168da99b1e499b1e689 100644 (file)
@@ -1,10 +1,12 @@
 SUBDIRS = \
-       yahttp \
-        json11
+       ipcrypt \
+       json11 \
+       yahttp
 
 DIST_SUBDIRS = \
-       yahttp \
-        json11
+       ipcrypt \
+       json11 \
+       yahttp
 
 EXTRA_DIST = \
        luawrapper/include/LuaContext.hpp
diff --git a/ext/ipcrypt/.gitignore b/ext/ipcrypt/.gitignore
new file mode 100644 (file)
index 0000000..24ad051
--- /dev/null
@@ -0,0 +1,5 @@
+*.la
+*.lo
+*.o
+Makefile
+Makefile.in
diff --git a/ext/ipcrypt/LICENSE b/ext/ipcrypt/LICENSE
new file mode 100644 (file)
index 0000000..0a199e5
--- /dev/null
@@ -0,0 +1,14 @@
+Copyright (c) 2015-2018, Frank Denis <j at pureftpd dot org>
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
diff --git a/ext/ipcrypt/Makefile.am b/ext/ipcrypt/Makefile.am
new file mode 100644 (file)
index 0000000..a68b2d4
--- /dev/null
@@ -0,0 +1,5 @@
+noinst_LTLIBRARIES = libipcrypt.la
+
+libipcrypt_la_SOURCES = \
+       ipcrypt.c \
+       ipcrypt.h
diff --git a/ext/ipcrypt/ipcrypt.c b/ext/ipcrypt/ipcrypt.c
new file mode 100644 (file)
index 0000000..6ef464a
--- /dev/null
@@ -0,0 +1,87 @@
+
+#include "ipcrypt.h"
+
+#define ROTL(X, R) (X) = (unsigned char) ((X) << (R)) | ((X) >> (8 - (R)))
+
+static void
+arx_fwd(unsigned char state[4])
+{
+    state[0] += state[1];
+    state[2] += state[3];
+    ROTL(state[1], 2);
+    ROTL(state[3], 5);
+    state[1] ^= state[0];
+    state[3] ^= state[2];
+    ROTL(state[0], 4);
+    state[0] += state[3];
+    state[2] += state[1];
+    ROTL(state[1], 3);
+    ROTL(state[3], 7);
+    state[1] ^= state[2];
+    state[3] ^= state[0];
+    ROTL(state[2], 4);
+}
+
+static void
+arx_bwd(unsigned char state[4])
+{
+    ROTL(state[2], 4);
+    state[1] ^= state[2];
+    state[3] ^= state[0];
+    ROTL(state[1], 5);
+    ROTL(state[3], 1);
+    state[0] -= state[3];
+    state[2] -= state[1];
+    ROTL(state[0], 4);
+    state[1] ^= state[0];
+    state[3] ^= state[2];
+    ROTL(state[1], 6);
+    ROTL(state[3], 3);
+    state[0] -= state[1];
+    state[2] -= state[3];
+}
+
+static inline void
+xor4(unsigned char *out, const unsigned char *x, const unsigned char *y)
+{
+    out[0] = x[0] ^ y[0];
+    out[1] = x[1] ^ y[1];
+    out[2] = x[2] ^ y[2];
+    out[3] = x[3] ^ y[3];
+}
+
+int
+ipcrypt_encrypt(unsigned char out[IPCRYPT_BYTES],
+                const unsigned char in[IPCRYPT_BYTES],
+                const unsigned char key[IPCRYPT_KEYBYTES])
+{
+    unsigned char state[4];
+
+    xor4(state, in, key);
+    arx_fwd(state);
+    xor4(state, state, key + 4);
+    arx_fwd(state);
+    xor4(state, state, key + 8);
+    arx_fwd(state);
+    xor4(out, state, key + 12);
+
+    return 0;
+}
+
+int
+ipcrypt_decrypt(unsigned char out[IPCRYPT_BYTES],
+                const unsigned char in[IPCRYPT_BYTES],
+                const unsigned char key[IPCRYPT_KEYBYTES])
+{
+    unsigned char state[4];
+
+    xor4(state, in, key + 12);
+    arx_bwd(state);
+    xor4(state, state, key + 8);
+    arx_bwd(state);
+    xor4(state, state, key + 4);
+    arx_bwd(state);
+    xor4(out, state, key);
+
+    return 0;
+}
diff --git a/ext/ipcrypt/ipcrypt.h b/ext/ipcrypt/ipcrypt.h
new file mode 100644 (file)
index 0000000..76b94f5
--- /dev/null
@@ -0,0 +1,24 @@
+
+#ifndef ipcrypt_H
+#define ipcrypt_H
+
+#define IPCRYPT_BYTES 4
+#define IPCRYPT_KEYBYTES 16
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ipcrypt_encrypt(unsigned char out[IPCRYPT_BYTES],
+                    const unsigned char in[IPCRYPT_BYTES],
+                    const unsigned char key[IPCRYPT_KEYBYTES]);
+
+int ipcrypt_decrypt(unsigned char out[IPCRYPT_BYTES],
+                    const unsigned char in[IPCRYPT_BYTES],
+                    const unsigned char key[IPCRYPT_KEYBYTES]);
+
+#ifdef __cplusplus
+}  /* End of the 'extern "C"' block */
+#endif
+
+#endif
index b0e6a39e4cf1b2f6eb5b4223f23e3d790429d099..c71c98acc8208cb22fee669e51cc582f2c731a0d 100644 (file)
@@ -90,6 +90,10 @@ AC_DEFUN([PDNS_CHECK_LIBCRYPTO], [
         # it will just work!
     fi
 
+    if $found; then
+        AC_DEFINE([HAVE_LIBCRYPTO], [1], [Define to 1 if you have OpenSSL libcrypto])
+    fi
+
     # try the preprocessor and linker with our new flags,
     # being careful not to pollute the global LIBS, LDFLAGS, and CPPFLAGS
 
@@ -120,4 +124,5 @@ AC_DEFUN([PDNS_CHECK_LIBCRYPTO], [
     AC_SUBST([LIBCRYPTO_INCLUDES])
     AC_SUBST([LIBCRYPTO_LIBS])
     AC_SUBST([LIBCRYPTO_LDFLAGS])
+    AM_CONDITIONAL([HAVE_LIBCRYPTO], [test "x$LIBCRYPTO_LIBS" != "x"])
 ])
index e3033abe3b1a7a2a68b5ca93b01b17cf4b45451c..48707d3a95f250ea6780d2c21484fb2a05a02968 100644 (file)
@@ -59,7 +59,8 @@ void gMySQLBackend::reconnect()
                    getArg("password"),
                    getArg("group"),
                    mustDo("innodb-read-committed"),
-                   getArgAsNum("timeout")));
+                   getArgAsNum("timeout"),
+                   mustDo("thread-cleanup")));
 }
 
 class gMySQLFactory : public BackendFactory
@@ -78,6 +79,7 @@ public:
     declare(suffix,"group", "Database backend MySQL 'group' to connect as", "client");
     declare(suffix,"innodb-read-committed","Use InnoDB READ-COMMITTED transaction isolation level","yes");
     declare(suffix,"timeout", "The timeout in seconds for each attempt to read/write to the server", "10");
+    declare(suffix,"thread-cleanup","Explicitly call mysql_thread_end() when threads end","no");
 
     declare(suffix,"dnssec","Enable DNSSEC processing","no");
 
index 1d851daa91e3094a8fc637a35f6ee1220963a2f0..0b062574776f7bfe9ccc0232c9849d5113df8364 100644 (file)
 typedef bool my_bool;
 #endif
 
+/*
+ * Older versions of the MySQL and MariaDB client leak memory
+ * because they expect the application to call mysql_thread_end()
+ * when a thread ends. This thread_local static object provides
+ * that closure, but only when the user has asked for it
+ * by setting gmysql-thread-cleanup.
+ * For more discussion, see https://github.com/PowerDNS/pdns/issues/6231
+ */
+class MySQLThreadCloser
+{
+public:
+  ~MySQLThreadCloser() {
+    if(d_enabled) {
+      mysql_thread_end();
+    }
+  }
+  void enable() {
+   d_enabled = true;
+  }
+
+private:
+  bool d_enabled = false;
+};
+
+static thread_local MySQLThreadCloser threadcloser;
+
 bool SMySQL::s_dolog;
 pthread_mutex_t SMySQL::s_myinitlock = PTHREAD_MUTEX_INITIALIZER;
 
@@ -419,6 +445,10 @@ void SMySQL::connect()
   int retry=1;
 
   Lock l(&s_myinitlock);
+  if (d_threadCleanup) {
+    threadcloser.enable();
+  }
+
   if (!mysql_init(&d_db))
     throw sPerrorException("Unable to initialize mysql driver");
 
@@ -467,8 +497,8 @@ void SMySQL::connect()
 }
 
 SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const string &msocket, const string &user,
-               const string &password, const string &group, bool setIsolation, unsigned int timeout):
-  d_database(database), d_host(host), d_msocket(msocket), d_user(user), d_password(password), d_group(group), d_timeout(timeout), d_port(port), d_setIsolation(setIsolation)
+               const string &password, const string &group, bool setIsolation, unsigned int timeout, bool threadCleanup):
+  d_database(database), d_host(host), d_msocket(msocket), d_user(user), d_password(password), d_group(group), d_timeout(timeout), d_port(port), d_setIsolation(setIsolation), d_threadCleanup(threadCleanup)
 {
   connect();
 }
index 7e31d7b5644f1da8d405318253abe22e08f6d305..7a33e8c529d12357430a4cf3418320a6a3860eca 100644 (file)
@@ -32,7 +32,8 @@ public:
   SMySQL(const string &database, const string &host="", uint16_t port=0,
          const string &msocket="",const string &user="",
          const string &password="", const string &group="",
-         bool setIsolation=false, unsigned int timeout=10);
+         bool setIsolation=false, unsigned int timeout=10,
+         bool threadCleanup=false);
 
   ~SMySQL();
 
@@ -61,6 +62,7 @@ private:
   unsigned int d_timeout;
   uint16_t d_port;
   bool d_setIsolation;
+  bool d_threadCleanup;
 };
 
 #endif /* SSMYSQL_HH */
index a64dd1110af1f9d4c76c8d0bf04b64de0673b6bd..e6da3dd7d21d9397c645545bcf270d4ebc887f81 100644 (file)
@@ -256,7 +256,7 @@ public:
       return count;
     }
 
-    //! End iderator type
+    //! End iterator type
     struct eiter_t
     {};
 
index 6cec85188deaabd01dded97d2090a87f69ce88f3..525beb897ba1c7a2c80eb64a602eaba1b4cc3294 100644 (file)
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
-#include "pdns/utility.hh"
 #include "pdns/dnsbackend.hh"
 #include "pdns/dns.hh"
 #include "pdns/dnsbackend.hh"
 #include "pdns/dnspacket.hh"
+#include "pdns/dns_random.hh"
 #include "pdns/pdnsexception.hh"
 #include "pdns/logger.hh"
 #include "pdns/version.hh"
@@ -59,7 +59,7 @@ public:
     } else if (qdomain == d_ourname) {
       if(type.getCode() == QType::A || type.getCode() == QType::ANY) {
         ostringstream os;
-        os<<Utility::random()%256<<"."<<Utility::random()%256<<"."<<Utility::random()%256<<"."<<Utility::random()%256;
+        os<<dns_random(256)<<"."<<dns_random(256)<<"."<<dns_random(256)<<"."<<dns_random(256);
         d_answer=os.str(); // our random ip address
       } else {
         d_answer="";
index 40030a515e85c4296c0dc84e69c388932b14c848..6749e4cf6e15e6d3aa4608ff818209b56a531ea4 100644 (file)
@@ -223,6 +223,7 @@ pdns_server_SOURCES = \
        tsigutils.hh tsigutils.cc \
        tkey.cc \
        ueberbackend.cc ueberbackend.hh \
+       uuid-utils.hh uuid-utils.cc \
        unix_semaphore.cc \
        unix_utility.cc \
        utility.hh \
@@ -312,6 +313,7 @@ pdnsutil_SOURCES = \
        ednsoptions.cc ednsoptions.hh \
        ednssubnet.cc \
        gss_context.cc gss_context.hh \
+       ipcipher.cc ipcipher.hh \
        iputils.cc iputils.hh \
        json.cc \
        logger.cc \
@@ -349,7 +351,8 @@ pdnsutil_LDADD = \
        $(JSON11_LIBS) \
        $(LIBDL) \
        $(BOOST_PROGRAM_OPTIONS_LIBS) \
-       $(LIBCRYPTO_LIBS)
+       $(LIBCRYPTO_LIBS) \
+       $(IPCRYPT_LIBS)
 
 if LIBSODIUM
 pdnsutil_SOURCES += sodiumsigners.cc
@@ -640,6 +643,7 @@ ixfrdist_SOURCES = \
        threadname.hh threadname.cc \
        tsigverifier.cc tsigverifier.hh \
        unix_utility.cc \
+       uuid-utils.hh uuid-utils.cc \
        webserver.hh webserver.cc \
        zoneparser-tng.cc
 
@@ -880,18 +884,22 @@ speedtest_LDADD = $(LIBCRYPTO_LIBS) \
        $(RT_LIBS)
 
 dnswasher_SOURCES = \
+       base64.cc \
        dnslabeltext.cc \
        dnsname.hh dnsname.cc \
        dnsparser.hh \
        dnspcap.cc dnspcap.hh \
        dnswasher.cc \
        dnswriter.hh \
+       ipcipher.cc ipcipher.hh \
        logger.cc \
        misc.cc \
        qtype.cc \
        statbag.cc \
        unix_utility.cc
 
+dnswasher_LDFLAGS =    $(AM_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LDFLAGS) $(LIBCRYPTO_LDFLAGS)
+dnswasher_LDADD =      $(BOOST_PROGRAM_OPTIONS_LIBS) $(LIBCRYPTO_LIBS) $(IPCRYPT_LIBS)
 
 dnsbulktest_SOURCES = \
        base32.cc \
@@ -1058,7 +1066,8 @@ pdns_notify_SOURCES = \
        rcpgenerator.cc rcpgenerator.hh \
        sillyrecords.cc \
        statbag.cc \
-       unix_utility.cc
+       unix_utility.cc \
+       dns_random.cc
 
 pdns_notify_LDFLAGS = \
        $(AM_LDFLAGS) \
@@ -1069,6 +1078,10 @@ pdns_notify_LDADD = \
        $(LIBCRYPTO_LIBS) \
        $(BOOST_PROGRAM_OPTIONS_LIBS)
 
+if LIBSODIUM
+pdns_notify_LDADD += $(LIBSODIUM_LIBS)
+endif
+
 dnsscope_SOURCES = \
        arguments.cc \
        base32.cc \
@@ -1267,6 +1280,7 @@ testrunner_SOURCES = \
        ednssubnet.cc \
        gettime.cc gettime.hh \
        gss_context.cc gss_context.hh \
+       ipcipher.cc ipcipher.hh \
        iputils.cc \
        ixfr.cc ixfr.hh \
        logger.cc \
@@ -1277,6 +1291,7 @@ testrunner_SOURCES = \
        nameserver.cc \
        nsecrecords.cc \
        opensslsigners.cc opensslsigners.hh \
+       pollmplexer.cc \
        qtype.cc \
        rcpgenerator.cc \
        responsestats.cc \
@@ -1296,11 +1311,13 @@ testrunner_SOURCES = \
        test-dnsparser_cc.cc \
        test-dnsparser_hh.cc \
        test-dnsrecords_cc.cc \
+       test-ipcrypt_cc.cc \
        test-iputils_hh.cc \
        test-ixfr_cc.cc \
        test-lock_hh.cc \
        test-lua_auth4_cc.cc \
        test-misc_hh.cc \
+       test-mplexer.cc \
        test-nameserver_cc.cc \
        test-packetcache_cc.cc \
        test-packetcache_hh.cc \
@@ -1327,7 +1344,8 @@ testrunner_LDADD = \
        $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) \
        $(RT_LIBS) \
        $(LUA_LIBS) \
-       $(LIBDL)
+       $(LIBDL) \
+       $(IPCRYPT_LIBS)
 
 if PKCS11
 testrunner_SOURCES += pkcs11signers.cc pkcs11signers.hh
@@ -1344,6 +1362,25 @@ testrunner_SOURCES += decafsigners.cc
 testrunner_LDADD += $(LIBDECAF_LIBS)
 endif
 
+if HAVE_FREEBSD
+ixfrdist_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
+endif
+
+if HAVE_LINUX
+ixfrdist_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
+endif
+
+if HAVE_SOLARIS
+ixfrdist_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
+testrunner_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
+endif
+
 pdns_control_SOURCES = \
        arguments.cc \
        dynloader.cc \
@@ -1431,6 +1468,7 @@ fuzz_target_packetcache_SOURCES = \
        ednsoptions.cc ednsoptions.hh \
        misc.cc misc.hh \
        packetcache.hh \
+       qtype.cc qtype.hh \
        statbag.cc statbag.hh
 
 fuzz_target_packetcache_DEPENDENCIES = $(fuzz_targets_deps)
index 670f460fa7a70faddee014e0a49697a69fda1e7f..f8e67cb0772b0054693347ed7def0679c53630f3 100644 (file)
@@ -21,9 +21,9 @@ Install dependencies from Homebrew:
 brew install autoconf automake boost libedit libsodium libtool lua pkg-config protobuf
 ```
 
-Let configure know where to find libedit:
+Let configure know where to find libedit, and openssl or libressl:
 
 ```sh
-./configure 'PKG_CONFIG_PATH=/usr/local/opt/libedit/lib/pkgconfig'
+./configure 'PKG_CONFIG_PATH=/usr/local/opt/libedit/lib/pkgconfig:/usr/local/opt/libressl/lib/pkgconfig'
 make
 ```
index 0c541e608b399069bf72530c9566a54360d0b07e..95f850c8116dc2b0d9c164a33fd7950558859876 100644 (file)
@@ -224,7 +224,7 @@ bool GSQLBackend::setMaster(const DNSName &domain, const string &ip)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to set master of domain '"+domain.toLogString()+"': "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to set master of domain '"+domain.toLogString()+"' to IP address " + ip + ": "+e.txtReason());
   }
   return true;
 }
@@ -241,7 +241,7 @@ bool GSQLBackend::setKind(const DNSName &domain, const DomainInfo::DomainKind ki
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to set kind of domain '"+domain.toLogString()+"': "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to set kind of domain '"+domain.toLogString()+"' to " + toUpper(DomainInfo::getKindString(kind)) + ": "+e.txtReason());
   }
   return true;
 }
@@ -258,7 +258,7 @@ bool GSQLBackend::setAccount(const DNSName &domain, const string &account)
             reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to set account of domain '"+domain.toLogString()+"': "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to set account of domain '"+domain.toLogString()+"' to '" + account + "': "+e.txtReason());
   }
   return true;
 }
@@ -277,7 +277,7 @@ bool GSQLBackend::getDomainInfo(const DNSName &domain, DomainInfo &di, bool getS
       reset();
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to retrieve information about a domain: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to retrieve information about domain '" + domain.toLogString() + "': "+e.txtReason());
   }
 
   int numanswers=d_result.size();
@@ -447,7 +447,7 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
           reset();
       }
       catch(SSqlException &e) {
-        throw PDNSException("GSQLBackend unable to update ordername and auth for domain_id "+itoa(domain_id)+": "+e.txtReason());
+        throw PDNSException("GSQLBackend unable to update ordername and auth for " + qname.toLogString() + " for domain_id "+itoa(domain_id)+", domain name '" + qname.toLogString() + "': "+e.txtReason());
       }
     } else {
       try {
@@ -463,7 +463,7 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
           reset();
       }
       catch(SSqlException &e) {
-        throw PDNSException("GSQLBackend unable to update ordername and auth per type for domain_id "+itoa(domain_id)+": "+e.txtReason());
+        throw PDNSException("GSQLBackend unable to update ordername and auth for " + qname.toLogString() + "|" + QType(qtype).getName() + " for domain_id "+itoa(domain_id)+": "+e.txtReason());
       }
     }
   } else {
@@ -479,7 +479,7 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
           reset();
       }
       catch(SSqlException &e) {
-        throw PDNSException("GSQLBackend unable to nullify ordername and update auth for domain_id "+itoa(domain_id)+": "+e.txtReason());
+        throw PDNSException("GSQLBackend unable to nullify ordername and update auth for " + qname.toLogString() + " for domain_id "+itoa(domain_id)+": "+e.txtReason());
       }
     } else {
       try {
@@ -494,7 +494,7 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
           reset();
       }
       catch(SSqlException &e) {
-        throw PDNSException("GSQLBackend unable to nullify ordername and update auth per type for domain_id "+itoa(domain_id)+": "+e.txtReason());
+        throw PDNSException("GSQLBackend unable to nullify ordername and update auth for " + qname.toLogString() + "|" + QType(qtype).getName() + " for domain_id "+itoa(domain_id)+": "+e.txtReason());
       }
     }
   }
@@ -586,7 +586,7 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
     d_afterOrderQuery_stmt->reset();
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to find before/after (after) for domain_id "+itoa(id)+": "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to find before/after (after) for domain_id "+itoa(id)+" and qname '"+ qname.toLogString() +"': "+e.txtReason());
   }
 
   if(after.empty()) {
@@ -604,7 +604,7 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
       d_firstOrderQuery_stmt->reset();
     }
     catch(SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to find before/after (first) for domain_id "+itoa(id)+": "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to find before/after (first) for domain_id "+itoa(id)+" and qname '"+ qname.toLogString() + "': "+e.txtReason());
     }
   }
 
@@ -631,7 +631,7 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
       d_beforeOrderQuery_stmt->reset();
     }
     catch(SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to find before/after (before) for domain_id "+itoa(id)+": "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to find before/after (before) for domain_id "+itoa(id)+" and qname '"+ qname.toLogString() + ": "+e.txtReason());
     }
 
     if(! unhashed.empty())
@@ -659,7 +659,7 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
       d_lastOrderQuery_stmt->reset();
     }
     catch(SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to find before/after (last) for domain_id "+itoa(id)+": "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to find before/after (last) for domain_id "+itoa(id)+" and qname '"+ qname.toLogString() + ": "+e.txtReason());
     }
   } else {
     before=qname;
@@ -685,7 +685,7 @@ bool GSQLBackend::addDomainKey(const DNSName& name, const KeyData& key, int64_t&
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to store key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to store key for domain '"+ name.toLogString() + "': "+e.txtReason());
   }
 
   try {
@@ -726,7 +726,7 @@ bool GSQLBackend::activateDomainKey(const DNSName& name, unsigned int id)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to activate key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to activate key with id "+ std::to_string(id) + " for domain '" + name.toLogString() + "': "+e.txtReason());
   }
   return true;
 }
@@ -746,7 +746,7 @@ bool GSQLBackend::deactivateDomainKey(const DNSName& name, unsigned int id)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to deactivate key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to deactivate key with id "+ std::to_string(id) + " for domain '" + name.toLogString() + "': "+e.txtReason());
   }
   return true;
 }
@@ -766,7 +766,7 @@ bool GSQLBackend::removeDomainKey(const DNSName& name, unsigned int id)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to remove key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to remove key with id "+ std::to_string(id) + " for domain '" + name.toLogString() + "': "+e.txtReason());
   }
   return true;
 }
@@ -797,7 +797,7 @@ bool GSQLBackend::getTSIGKey(const DNSName& name, DNSName* algorithm, string* co
     d_getTSIGKeyQuery_stmt->reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to retrieve named TSIG key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to retrieve TSIG key with name '" + name.toLogString() + "': "+e.txtReason());
   }
 
   return !content->empty();
@@ -816,7 +816,7 @@ bool GSQLBackend::setTSIGKey(const DNSName& name, const DNSName& algorithm, cons
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to store named TSIG key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to store TSIG key with name '" + name.toLogString() + "' and algorithm '" + algorithm.toString() + "': "+e.txtReason());
   }
   return true;
 }
@@ -832,7 +832,7 @@ bool GSQLBackend::deleteTSIGKey(const DNSName& name)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to store named TSIG key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to delete TSIG key with name '" + name.toLogString() + "': "+e.txtReason());
   }
   return true;
 }
@@ -937,7 +937,7 @@ bool GSQLBackend::getAllDomainMetadata(const DNSName& name, std::map<std::string
     d_GetAllDomainMetadataQuery_stmt->reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to list metadata: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to list metadata for domain '" + name.toLogString() + "': "+e.txtReason());
   }
 
   return true;
@@ -968,7 +968,7 @@ bool GSQLBackend::getDomainMetadata(const DNSName& name, const std::string& kind
     d_GetDomainMetadataQuery_stmt->reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to list metadata: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to get metadata kind '" + kind + "' for domain '" + name.toLogString() + "': "+e.txtReason());
   }
 
   return true;
@@ -999,7 +999,7 @@ bool GSQLBackend::setDomainMetadata(const DNSName& name, const std::string& kind
     }
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to store metadata key: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to set metadata kind '" + kind + "' for domain '" + name.toLogString() + "': "+e.txtReason());
   }
   
   return true;
@@ -1045,7 +1045,7 @@ void GSQLBackend::lookup(const QType &qtype,const DNSName &qname, DNSPacket *pkt
       execute();
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend lookup query:"+e.txtReason());
+    throw PDNSException("GSQLBackend unable to lookup '" + qname.toLogString() + "|" + qtype.getName() + "':"+e.txtReason());
   }
 
   d_qname=qname;
@@ -1066,7 +1066,7 @@ bool GSQLBackend::list(const DNSName &target, int domain_id, bool include_disabl
       execute();
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend list query: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to list domain '" + target.toLogString() + "': "+e.txtReason());
   }
 
   d_qname.clear();
@@ -1089,7 +1089,7 @@ bool GSQLBackend::listSubZone(const DNSName &zone, int domain_id) {
       execute();      
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend listSubZone query: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to list SubZones for domain '" + zone.toLogString() + "': "+e.txtReason());
   }
   d_qname.clear();
   return true;
@@ -1140,7 +1140,7 @@ bool GSQLBackend::superMasterBackend(const string &ip, const DNSName &domain, co
         reset();
     }
     catch (SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to search for a domain: "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to search for a supermaster with IP " + ip + " and nameserver name '" + i->content + "' for domain '" + domain.toLogString() + "': "+e.txtReason());
     }
     if(!d_result.empty()) {
       ASSERT_ROW_COLUMNS("supermaster-query", d_result[0], 1);
@@ -1319,7 +1319,7 @@ bool GSQLBackend::replaceRRSet(uint32_t domain_id, const DNSName& qname, const Q
     }
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to delete RRSet: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to delete RRSet " + qname.toLogString() + "|" + qt.getName() + ": "+e.txtReason());
   }
 
   if (rrset.empty()) {
@@ -1334,7 +1334,7 @@ bool GSQLBackend::replaceRRSet(uint32_t domain_id, const DNSName& qname, const Q
         reset();
     }
     catch (SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to delete comment: "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to delete comment for RRSet " + qname.toLogString() + "|" + qt.getName() + ": "+e.txtReason());
     }
   }
   for(const auto& rr: rrset) {
@@ -1384,7 +1384,7 @@ bool GSQLBackend::feedRecord(const DNSResourceRecord &r, const DNSName &ordernam
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to feed record: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to feed record " + r.qname.toLogString() + "|" + r.qtype.getName() + ": "+e.txtReason());
   }
   return true; // XXX FIXME this API should not return 'true' I think -ahu 
 }
@@ -1404,7 +1404,7 @@ bool GSQLBackend::feedEnts(int domain_id, map<DNSName,bool>& nonterm)
         reset();
     }
     catch (SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to feed empty non-terminal: "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to feed empty non-terminal with name '" + nt.first.toLogString() + "': "+e.txtReason());
     }
   }
   return true;
@@ -1438,7 +1438,7 @@ bool GSQLBackend::feedEnts3(int domain_id, const DNSName &domain, map<DNSName,bo
         reset();
     }
     catch (SSqlException &e) {
-      throw PDNSException("GSQLBackend unable to feed empty non-terminal: "+e.txtReason());
+      throw PDNSException("GSQLBackend unable to feed empty non-terminal with name '" + nt.first.toLogString() + "' (hashed name '"+ toBase32Hex(hashQNameWithSalt(ns3prc, nt.first)) + "') : "+e.txtReason());
     }
   }
   return true;
@@ -1460,7 +1460,7 @@ bool GSQLBackend::startTransaction(const DNSName &domain, int domain_id)
   }
   catch (SSqlException &e) {
     d_inTransaction = false;
-    throw PDNSException("Database failed to start transaction: "+e.txtReason());
+    throw PDNSException("Database failed to start transaction for domain '" + domain.toLogString() + "': "+e.txtReason());
   }
 
   return true;
@@ -1504,7 +1504,7 @@ bool GSQLBackend::listComments(const uint32_t domain_id)
       execute();
   }
   catch(SSqlException &e) {
-    throw PDNSException("GSQLBackend list comments query: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to list comments for domain id " + std::to_string(domain_id) + ": "+e.txtReason());
   }
 
   return true;
@@ -1556,7 +1556,7 @@ void GSQLBackend::feedComment(const Comment& comment)
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to feed comment: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to feed comment for RRSet '" + comment.qname.toLogString() + "|" + comment.qtype.getName() + "': "+e.txtReason());
   }
 }
 
@@ -1573,7 +1573,7 @@ bool GSQLBackend::replaceComments(const uint32_t domain_id, const DNSName& qname
       reset();
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to delete comment: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to delete comment for RRSet '" + qname.toLogString() + "|" + qt.getName() + "': "+e.txtReason());
   }
 
   for(const auto& comment: comments) {
@@ -1606,7 +1606,7 @@ string GSQLBackend::directBackendCmd(const string &query)
    return out.str();
  }
  catch (SSqlException &e) {
-   throw PDNSException("GSQLBackend unable to execute query: "+e.txtReason());
+   throw PDNSException("GSQLBackend unable to execute direct command query '" + query + "': "+e.txtReason());
  }
 }
 
@@ -1623,9 +1623,8 @@ string GSQLBackend::pattern2SQLPattern(const string &pattern)
 bool GSQLBackend::searchRecords(const string &pattern, int maxResults, vector<DNSResourceRecord>& result)
 {
   d_qname.clear();
+  string escaped_pattern = pattern2SQLPattern(pattern);
   try {
-    string escaped_pattern = pattern2SQLPattern(pattern);
-
     reconnectIfNeeded();
 
     d_SearchRecordsQuery_stmt->
@@ -1653,7 +1652,7 @@ bool GSQLBackend::searchRecords(const string &pattern, int maxResults, vector<DN
     return true;
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to execute query: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to search for records with pattern '" + pattern + "' (escaped pattern '" + escaped_pattern + "'): "+e.txtReason());
   }
 
   return false;
@@ -1662,9 +1661,8 @@ bool GSQLBackend::searchRecords(const string &pattern, int maxResults, vector<DN
 bool GSQLBackend::searchComments(const string &pattern, int maxResults, vector<Comment>& result)
 {
   Comment c;
+  string escaped_pattern = pattern2SQLPattern(pattern);
   try {
-    string escaped_pattern = pattern2SQLPattern(pattern);
-
     reconnectIfNeeded();
 
     d_SearchCommentsQuery_stmt->
@@ -1687,7 +1685,7 @@ bool GSQLBackend::searchComments(const string &pattern, int maxResults, vector<C
     return true;
   }
   catch (SSqlException &e) {
-    throw PDNSException("GSQLBackend unable to execute query: "+e.txtReason());
+    throw PDNSException("GSQLBackend unable to search for comments with pattern '" + pattern + "' (escaped pattern '" + escaped_pattern + "'): "+e.txtReason());
   }
 
   return false;
index a82a4fc03d94c851e43b4a739c1a7b7bbbdbccda..f2c8bb0a302f02eca6dba6151ce3bfba964721aa 100644 (file)
@@ -377,7 +377,7 @@ try
 
     DNSPacketWriter pw(packet, DNSName(qname), DNSRecordContent::TypeToNumber(qtype));
     pw.getHeader()->rd=wantRecursion;
-    pw.getHeader()->id=random();
+    pw.getHeader()->id=dns_random(UINT16_MAX);
 
     if(!subnet.empty() || !ecsRange.empty()) {
       EDNSSubnetOpts opt;
@@ -454,7 +454,7 @@ try
       known.push_back(ptr);
     }
     for(;n < total; ++n) {
-      toSend.push_back(known[random()%known.size()].get());
+      toSend.push_back(known[dns_random(known.size())].get());
     }
     random_shuffle(toSend.begin(), toSend.end());
     g_recvcounter.store(0);
index e6c487169aaafa32ab024f3e32ac4c42ea2e1165..e7f6f5ab95255b3ebe17a214c2420ae131aa0634 100644 (file)
@@ -150,6 +150,7 @@ void declareArguments()
   ::arg().set("webserver-port","Port of webserver/API to listen on")="8081";
   ::arg().set("webserver-password","Password required for accessing the webserver")="";
   ::arg().set("webserver-allow-from","Webserver/API access is only allowed from these subnets")="127.0.0.1,::1";
+  ::arg().set("webserver-loglevel", "Amount of logging in the webserver (none, normal, detailed)") = "normal";
 
   ::arg().setSwitch("do-ipv6-additional-processing", "Do AAAA additional processing")="yes";
   ::arg().setSwitch("query-logging","Hint backends that queries should be logged")="no";
@@ -336,11 +337,11 @@ void declareStats(void)
   S.declare("latency","Average number of microseconds needed to answer a question", getLatency);
   S.declare("timedout-packets","Number of packets which weren't answered within timeout set");
   S.declare("security-status", "Security status based on regular polling");
-  S.declareRing("queries","UDP Queries Received");
-  S.declareRing("nxdomain-queries","Queries for non-existent records within existent domains");
-  S.declareRing("noerror-queries","Queries for existing records, but for type we don't have");
-  S.declareRing("servfail-queries","Queries that could not be answered due to backend errors");
-  S.declareRing("unauth-queries","Queries for domains that we are not authoritative for");
+  S.declareDNSNameQTypeRing("queries","UDP Queries Received");
+  S.declareDNSNameQTypeRing("nxdomain-queries","Queries for non-existent records within existent domains");
+  S.declareDNSNameQTypeRing("noerror-queries","Queries for existing records, but for type we don't have");
+  S.declareDNSNameQTypeRing("servfail-queries","Queries that could not be answered due to backend errors");
+  S.declareDNSNameQTypeRing("unauth-queries","Queries for domains that we are not authoritative for");
   S.declareRing("logmessages","Log Messages");
   S.declareComboRing("remotes","Remote server IP addresses");
   S.declareComboRing("remotes-unauth","Remote hosts querying domains for which we are not auth");
@@ -422,7 +423,7 @@ try
      if(P->d.qr)
        continue;
 
-    S.ringAccount("queries", P->qdomain.toLogString()+"/"+P->qtype.getName());
+    S.ringAccount("queries", P->qdomain, P->qtype);
     S.ringAccount("remotes",P->d_remote);
     if(logDNSQueries) {
       string remote;
@@ -497,9 +498,9 @@ static void triggerLoadOfLibraries()
 
 void mainthread()
 {
-   Utility::srandom(time(0) ^ getpid());
+   Utility::srandom();
 
-   int newgid=0;      
+   int newgid=0;
    if(!::arg()["setgid"].empty()) 
      newgid=Utility::makeGidNumeric(::arg()["setgid"]);      
    int newuid=0;      
index 2f63af582483931ffd4ad9e3c65b11b55464b848..e4aa0b52821217f6dc738ce08c1de4206e2bafb8 100644 (file)
@@ -692,7 +692,7 @@ bool DNSSECKeeper::rectifyZone(const DNSName& zone, string& error, string& info,
   }
 
   set<DNSName> nsec3set;
-  if (haveNSEC3 && !narrow) {
+  if (haveNSEC3 && (!narrow || !isOptOut)) {
     for (auto &loopRR: rrs) {
       bool skip=false;
       DNSName shorter = loopRR.qname;
@@ -743,12 +743,12 @@ bool DNSSECKeeper::rectifyZone(const DNSName& zone, string& error, string& info,
 
     if(haveNSEC3) // NSEC3
     {
-      if(!narrow && nsec3set.count(qname)) {
-        ordername=DNSName(toBase32Hex(hashQNameWithSalt(ns3pr, qname)));
-        if(!realrr)
+      if(nsec3set.count(qname)) {
+        if(!narrow)
+          ordername=DNSName(toBase32Hex(hashQNameWithSalt(ns3pr, qname)));
+        if(!realrr && !isOptOut)
           auth=true;
-      } else if(!realrr)
-        auth=false;
+      }
     }
     else if (realrr && securedZone) // NSEC
       ordername=qname.makeRelative(zone);
index f9234965a8f9c56a9dc218ecd847cc70a1c6b3ff..35df6dc8fc9940f5e498d409a4cab5659da1ce6d 100644 (file)
@@ -49,7 +49,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -82,9 +82,9 @@ DevPollFDMultiplexer::DevPollFDMultiplexer()
     
 }
 
-void DevPollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+void DevPollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct pollfd devent;
   devent.fd=fd;
@@ -160,13 +160,13 @@ int DevPollFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(dvp.dp_fds[n].fd);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't refind ourselves as writable!
     }
     d_iter=d_writeCallbacks.find(dvp.dp_fds[n].fd);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
   delete[] dvp.dp_fds;
index aa53f64a11050ef3458b5c4f41dff856fb8f1c45..ecc1509cd49525ca92f663e1c764e17845039642 100644 (file)
@@ -322,13 +322,12 @@ template<class Answer, class Question, class Backend>int MultiThreadDistributor<
   QD->Q=q;
   auto ret = QD->id = nextid++; // might be deleted after write!
   QD->callback=callback;
-  
-  if(write(d_pipes[QD->id % d_pipes.size()].second, &QD, sizeof(QD)) != sizeof(QD))
-    unixDie("write");
-
-  d_queued++;
-
 
+  ++d_queued;
+  if(write(d_pipes[QD->id % d_pipes.size()].second, &QD, sizeof(QD)) != sizeof(QD)) {
+    --d_queued;
+    unixDie("write");
+  }
 
   if(d_queued > d_maxQueueLength) {
     g_log<<Logger::Error<< d_queued <<" questions waiting for database/backend attention. Limit is "<<::arg().asNum("max-queue-length")<<", respawning"<<endl;
index 40876017cd5527f2b30227c0ce09d420f0332fb6..307784cb8092ee5fcaa3ca3c42d003cbba0cb382 100644 (file)
 
 #ifndef HAVE_DNSCRYPT
 
+/* let's just define a few types and values so that the rest of
+   the code can ignore whether DNSCrypt support is available */
+#define DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE (0)
+
+class DNSCryptContext
+{
+};
+
 class DNSCryptQuery
 {
+  DNSCryptQuery(const std::shared_ptr<DNSCryptContext>& ctx): d_ctx(ctx)
+  {
+  }
+private:
+  std::shared_ptr<DNSCryptContext> d_ctx{nullptr};
 };
 
-#else
+#else /* HAVE_DNSCRYPT */
 
 #include <memory>
 #include <string>
index 8c6f86f7cd59740f79929b5813b16525f155efa0..2736dc0480ce441d59bf0d1e3facc3b681a2a15d 100644 (file)
@@ -233,7 +233,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t
     }
 
     const CacheValue& value = it->second;
-    if (value.validity < now) {
+    if (value.validity <= now) {
       if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
         d_misses++;
         return false;
@@ -292,12 +292,13 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t
 /* Remove expired entries, until the cache has at most
    upTo entries in it.
 */
-void DNSDistPacketCache::purgeExpired(size_t upTo)
+size_t DNSDistPacketCache::purgeExpired(size_t upTo)
 {
+  size_t removed = 0;
   uint64_t size = getSize();
 
   if (size == 0 || upTo >= size) {
-    return;
+    return removed;
   }
 
   size_t toRemove = size - upTo;
@@ -313,10 +314,11 @@ void DNSDistPacketCache::purgeExpired(size_t upTo)
     for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
       const CacheValue& value = it->second;
 
-      if (value.validity < now) {
+      if (value.validity <= now) {
         it = map.erase(it);
         --toRemove;
         d_shards[shardIndex].d_entriesCount--;
+        ++removed;
       } else {
         ++it;
       }
@@ -325,20 +327,22 @@ void DNSDistPacketCache::purgeExpired(size_t upTo)
     scannedMaps++;
   }
   while (toRemove > 0 && scannedMaps < d_shardCount);
+
+  return removed;
 }
 
 /* Remove all entries, keeping only upTo
    entries in the cache */
-void DNSDistPacketCache::expunge(size_t upTo)
+size_t DNSDistPacketCache::expunge(size_t upTo)
 {
+  size_t removed = 0;
   const uint64_t size = getSize();
 
   if (upTo >= size) {
-    return;
+    return removed;
   }
 
   size_t toRemove = size - upTo;
-  size_t removed = 0;
 
   for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
     WriteLock w(&d_shards.at(shardIndex).d_lock);
@@ -358,10 +362,14 @@ void DNSDistPacketCache::expunge(size_t upTo)
       d_shards[shardIndex].d_entriesCount = 0;
     }
   }
+
+  return removed;
 }
 
-void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
+size_t DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
 {
+  size_t removed = 0;
+
   for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
     WriteLock w(&d_shards.at(shardIndex).d_lock);
     auto& map = d_shards[shardIndex].d_map;
@@ -372,11 +380,14 @@ void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool
       if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
         it = map.erase(it);
         d_shards[shardIndex].d_entriesCount--;
+        ++removed;
       } else {
         ++it;
       }
     }
   }
+
+  return removed;
 }
 
 bool DNSDistPacketCache::isFull()
index 830a01609ae83d3ca3732d9a24f71643995c9cd4..887f03d32bf63d01158cb1c9a0192f495778241a 100644 (file)
@@ -37,9 +37,9 @@ public:
 
   void insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL);
   bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional<Netmask>& subnetOut, bool dnssecOK, uint32_t allowExpired=0, bool skipAging=false);
-  void purgeExpired(size_t upTo=0);
-  void expunge(size_t upTo=0);
-  void expungeByName(const DNSName& name, uint16_t qtype=QType::ANY, bool suffixMatch=false);
+  size_t purgeExpired(size_t upTo=0);
+  size_t expunge(size_t upTo=0);
+  size_t expungeByName(const DNSName& name, uint16_t qtype=QType::ANY, bool suffixMatch=false);
   bool isFull();
   string toString();
   uint64_t getSize();
index f8f4d943d881175c4a20446891cad90d21c46000..1fa49f58468010255209bff4ab789ac568a073d8 100644 (file)
@@ -95,6 +95,14 @@ try
           str<<base<<"latency" << ' ' << (state->availability != DownstreamState::Availability::Down ? state->latencyUsec/1000.0 : 0) << " " << now << "\r\n";
           str<<base<<"senderrors" << ' ' << state->sendErrors.load() << " " << now << "\r\n";
           str<<base<<"outstanding" << ' ' << state->outstanding.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedsendingquery" << ' '<< state->tcpDiedSendingQuery.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedreaddingresponse" << ' '<< state->tcpDiedReadingResponse.load() << " " << now << "\r\n";
+          str<<base<<"tcpgaveup" << ' '<< state->tcpGaveUp.load() << " " << now << "\r\n";
+          str<<base<<"tcpreadimeouts" << ' '<< state->tcpReadTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpwritetimeouts" << ' '<< state->tcpWriteTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpcurrentconnections" << ' '<< state->tcpCurrentConnections.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgqueriesperconnection" << ' '<< state->tcpAvgQueriesPerConnection.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgconnectionduration" << ' '<< state->tcpAvgConnectionDuration.load() << " " << now << "\r\n";
         }
         for(const auto& front : g_frontends) {
           if (front->udpFD == -1 && front->tcpFD == -1)
@@ -104,6 +112,14 @@ try
           boost::replace_all(frontName, ".", "_");
           const string base = namespace_name + "." + hostname + "." + instance_name + ".frontends." + frontName + ".";
           str<<base<<"queries" << ' ' << front->queries.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedreadingquery" << ' '<< front->tcpDiedReadingQuery.load() << " " << now << "\r\n";
+          str<<base<<"tcpdiedsendingresponse" << ' '<< front->tcpDiedSendingResponse.load() << " " << now << "\r\n";
+          str<<base<<"tcpgaveup" << ' '<< front->tcpGaveUp.load() << " " << now << "\r\n";
+          str<<base<<"tcpclientimeouts" << ' '<< front->tcpClientTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpdownstreamtimeouts" << ' '<< front->tcpDownstreamTimeouts.load() << " " << now << "\r\n";
+          str<<base<<"tcpcurrentconnections" << ' '<< front->tcpCurrentConnections.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgqueriesperconnection" << ' '<< front->tcpAvgQueriesPerConnection.load() << " " << now << "\r\n";
+          str<<base<<"tcpavgconnectionduration" << ' '<< front->tcpAvgConnectionDuration.load() << " " << now << "\r\n";
         }
         auto localPools = g_pools.getLocal();
         for (const auto& entry : *localPools) {
index 87e6b10ef985ff6e19e957cb7473d35652cd732a..3d0dd4d039f02ff307296a819116b7bf63f5a9c8 100644 (file)
@@ -347,8 +347,6 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "addDNSCryptBind", true, "\"127.0.0.1:8443\", \"provider name\", \"/path/to/resolver.cert\", \"/path/to/resolver.key\", {reusePort=false, tcpFastOpenSize=0, interface=\"\", cpus={}}", "listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of `provider name`, using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files. The fifth optional parameter is a table of parameters" },
   { "addDynBlocks", true, "addresses, message[, seconds[, action]]", "block the set of addresses with message `msg`, for `seconds` seconds (10 by default), applying `action` (default to the one set with `setDynBlocksAction()`)" },
   { "addLocal", true, "addr [, {doTCP=true, reusePort=false, tcpFastOpenSize=0, interface=\"\", cpus={}}]", "add `addr` to the list of addresses we listen on" },
-  { "addLuaAction", true, "x, func [, {uuid=\"UUID\"}]", "where 'x' is all the combinations from `addAction`, and func is a function with the parameter `dq`, which returns an action to be taken on this packet. Good for rare packets but where you want to do a lot of processing" },
-  { "addLuaResponseAction", true, "x, func [, {uuid=\"UUID\"}]", "where 'x' is all the combinations from `addAction`, and func is a function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing" },
   { "addCacheHitResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a cache hit response rule" },
   { "addResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a response rule" },
   { "addSelfAnsweredResponseAction", true, "DNS rule, DNS response action [, {uuid=\"UUID\"}]", "add a self-answered response rule" },
@@ -474,6 +472,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "setServerPolicyLua", true, "name, function", "set server selection policy to one named 'name' and provided by 'function'" },
   { "setServFailWhenNoServer", true, "bool", "if set, return a ServFail when no servers are available, instead of the default behaviour of dropping the query" },
   { "setStaleCacheEntriesTTL", true, "n", "allows using cache entries expired for at most n seconds when there is no backend available to answer for a query" },
+  { "setSyslogFacility", true, "facility", "set the syslog logging facility to 'facility'. Defaults to LOG_DAEMON" },
   { "setTCPDownstreamCleanupInterval", true, "interval", "minimum interval in seconds between two cleanups of the idle TCP downstream connections" },
   { "setTCPUseSinglePipe", true, "bool", "whether the incoming TCP connections should be put into a single queue instead of using per-thread queues. Defaults to false" },
   { "setTCPRecvTimeout", true, "n", "set the read timeout on TCP connections from the client, in seconds" },
index 73d4af36d82a25e5d1fee16bdcccf24354704240..5d9923fb07efafbd5b188d6fd767b53ce6b4a8ce 100644 (file)
  */
 #pragma once
 
+#include <unordered_set>
+
 #include "dolog.hh"
 #include "dnsdist-rings.hh"
+#include "statnode.hh"
+
+#include "dnsdist-lua-inspection-ffi.hh"
+
+// dnsdist_ffi_stat_node_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_stat_node_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_stat_node_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+typedef std::function<bool(dnsdist_ffi_stat_node_t*)> dnsdist_ffi_stat_node_visitor_t;
+
+struct dnsdist_ffi_stat_node_t
+{
+  dnsdist_ffi_stat_node_t(const StatNode& node_, const StatNode::Stat& self_, const StatNode::Stat& children_): node(node_), self(self_), children(children_)
+  {
+  }
+
+  const StatNode& node;
+  const StatNode::Stat& self;
+  const StatNode::Stat& children;
+};
 
 class DynBlockRulesGroup
 {
@@ -119,7 +149,7 @@ private:
     bool d_enabled{false};
   };
 
-    typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
+  typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
 
 public:
   DynBlockRulesGroup()
@@ -149,6 +179,20 @@ public:
     entry = DynBlockRule(reason, blockDuration, rate, warningRate, seconds, action);
   }
 
+  typedef std::function<bool(const StatNode&, const StatNode::Stat&, const StatNode::Stat&)> smtVisitor_t;
+
+  void setSuffixMatchRule(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, smtVisitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitor = visitor;
+  }
+
+  void setSuffixMatchRuleFFI(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, dnsdist_ffi_stat_node_visitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitorFFI = visitor;
+  }
+
   void apply()
   {
     struct timespec now;
@@ -160,6 +204,7 @@ public:
   void apply(const struct timespec& now)
   {
     counts_t counts;
+    StatNode statNodeRoot;
 
     size_t entriesCount = 0;
     if (hasQueryRules()) {
@@ -171,9 +216,9 @@ public:
     counts.reserve(entriesCount);
 
     processQueryRules(counts, now);
-    processResponseRules(counts, now);
+    processResponseRules(counts, statNodeRoot, now);
 
-    if (counts.empty()) {
+    if (counts.empty() && statNodeRoot.empty()) {
       return;
     }
 
@@ -239,6 +284,38 @@ public:
     if (updated && blocks) {
       g_dynblockNMG.setState(*blocks);
     }
+
+    if (!statNodeRoot.empty()) {
+      StatNode::Stat node;
+      std::unordered_set<DNSName> namesToBlock;
+      statNodeRoot.visit([this,&namesToBlock](const StatNode* node_, const StatNode::Stat& self, const StatNode::Stat& children) {
+                           bool block = false;
+
+                           if (d_smtVisitorFFI) {
+                             dnsdist_ffi_stat_node_t tmp(*node_, self, children);
+                             block = d_smtVisitorFFI(&tmp);
+                           }
+                           else {
+                             block = d_smtVisitor(*node_, self, children);
+                           }
+
+                           if (block) {
+                             namesToBlock.insert(DNSName(node_->fullname));
+                           }
+                         },
+        node);
+
+      if (!namesToBlock.empty()) {
+        updated = false;
+        SuffixMatchTree<DynBlock> smtBlocks = g_dynblockSMT.getCopy();
+        for (const auto& name : namesToBlock) {
+          addOrRefreshBlockSMT(smtBlocks, now, name, d_suffixMatchRule, updated);
+        }
+        if (updated) {
+          g_dynblockSMT.setState(smtBlocks);
+        }
+      }
+    }
   }
 
   void excludeRange(const Netmask& range)
@@ -251,12 +328,18 @@ public:
     d_excludedSubnets.addMask(range, false);
   }
 
+  void excludeDomain(const DNSName& domain)
+  {
+    d_excludedDomains.add(domain);
+  }
+
   std::string toString() const
   {
     std::stringstream result;
 
     result << "Query rate rule: " << d_queryRateRule.toString() << std::endl;
     result << "Response rate rule: " << d_respRateRule.toString() << std::endl;
+    result << "SuffixMatch rule: " << d_suffixMatchRule.toString() << std::endl;
     result << "RCode rules: " << std::endl;
     for (const auto& rule : d_rcodeRules) {
       result << "- " << RCode::to_s(rule.first) << ": " << rule.second.toString() << std::endl;
@@ -266,6 +349,7 @@ public:
       result << "- " << QType(rule.first).getName() << ": " << rule.second.toString() << std::endl;
     }
     result << "Excluded Subnets: " << d_excludedSubnets.toString() << std::endl;
+    result << "Excluded Domains: " << d_excludedDomains.toString() << std::endl;
 
     return result.str();
   }
@@ -348,6 +432,44 @@ private:
     updated = true;
   }
 
+  void addOrRefreshBlockSMT(SuffixMatchTree<DynBlock>& blocks, const struct timespec& now, const DNSName& name, const DynBlockRule& rule, bool& updated)
+  {
+    if (d_excludedDomains.check(name)) {
+      /* do not add a block for excluded domains */
+      return;
+    }
+
+    struct timespec until = now;
+    until.tv_sec += rule.d_blockDuration;
+    unsigned int count = 0;
+    const auto& got = blocks.lookup(name);
+    bool expired = false;
+
+    if (got) {
+      if (until < got->until) {
+        // had a longer policy
+        return;
+      }
+
+      if (now < got->until) {
+        // only inherit count on fresh query we are extending
+        count = got->blocks;
+      }
+      else {
+        expired = true;
+      }
+    }
+
+    DynBlock db{rule.d_blockReason, until, name, rule.d_action};
+    db.blocks = count;
+
+    if (!d_beQuiet && (!got || expired)) {
+      warnlog("Inserting dynamic block for %s for %d seconds: %s", name, rule.d_blockDuration, rule.d_blockReason);
+    }
+    blocks.add(name, db);
+    updated = true;
+  }
+
   void addBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const ComboAddress& requestor, const DynBlockRule& rule, bool& updated)
   {
     addOrRefreshBlock(blocks, now, requestor, rule, updated, false);
@@ -368,6 +490,11 @@ private:
     return d_respRateRule.isEnabled() || !d_rcodeRules.empty();
   }
 
+  bool hasSuffixMatchRules() const
+  {
+    return d_suffixMatchRule.isEnabled();
+  }
+
   bool hasRules() const
   {
     return hasQueryRules() || hasResponseRules();
@@ -410,15 +537,18 @@ private:
     }
   }
 
-  void processResponseRules(counts_t& counts, const struct timespec& now)
+  void processResponseRules(counts_t& counts, StatNode& root, const struct timespec& now)
   {
-    if (!hasResponseRules()) {
+    if (!hasResponseRules() && !hasSuffixMatchRules()) {
       return;
     }
 
     d_respRateRule.d_cutOff = d_respRateRule.d_minTime = now;
     d_respRateRule.d_cutOff.tv_sec -= d_respRateRule.d_seconds;
 
+    d_suffixMatchRule.d_cutOff = d_suffixMatchRule.d_minTime = now;
+    d_suffixMatchRule.d_cutOff.tv_sec -= d_suffixMatchRule.d_seconds;
+
     for (auto& rule : d_rcodeRules) {
       rule.second.d_cutOff = rule.second.d_minTime = now;
       rule.second.d_cutOff.tv_sec -= rule.second.d_seconds;
@@ -432,6 +562,7 @@ private:
         }
 
         bool respRateMatches = d_respRateRule.matches(c.when);
+        bool suffixMatchRuleMatches = d_suffixMatchRule.matches(c.when);
         bool rcodeRuleMatches = checkIfResponseCodeMatches(c);
 
         if (respRateMatches || rcodeRuleMatches) {
@@ -443,6 +574,10 @@ private:
             entry.d_rcodeCounts[c.dh.rcode]++;
           }
         }
+
+        if (suffixMatchRuleMatches) {
+          root.submit(c.name, c.dh.rcode, boost::none);
+        }
       }
     }
   }
@@ -451,6 +586,10 @@ private:
   std::map<uint16_t, DynBlockRule> d_qtypeRules;
   DynBlockRule d_queryRateRule;
   DynBlockRule d_respRateRule;
+  DynBlockRule d_suffixMatchRule;
   NetmaskGroup d_excludedSubnets;
+  SuffixMatchNode d_excludedDomains;
+  smtVisitor_t d_smtVisitor;
+  dnsdist_ffi_stat_node_visitor_t d_smtVisitorFFI;
   bool d_beQuiet{false};
 };
index 56422e48a8a992ca5cceb60be62598c072924144..5e8974d6983c9d9fdcd47aac8c78ee1b1273cba5 100644 (file)
@@ -257,10 +257,10 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
-  dh.d_clen = htons((uint16_t) optRData.length());
+  dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
   res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
-  res.assign((const char *) &name, sizeof name);
-  res.append((const char *) &dh, sizeof dh);
+  res.assign(reinterpret_cast<const char *>(&name), sizeof name);
+  res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
   res.append(optRData.c_str(), optRData.length());
 }
 
@@ -464,9 +464,7 @@ static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16
 
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
 {
-  /* we need at least:
-     root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (*optLen < 11) {
+  if (*optLen < optRecordMinimumSize) {
     return EINVAL;
   }
   const unsigned char* end = (const unsigned char*) optStart + *optLen;
@@ -490,15 +488,13 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
 
 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
 {
-  /* we need at least:
-   root label (1), type (2), class (2), ttl (4) + rdlen (2)*/
-  if (optLen < 11) {
+  if (optLen < optRecordMinimumSize) {
     return false;
   }
   size_t p = optStart + 9;
   uint16_t rdLen = (0x100*packet.at(p) + packet.at(p+1));
   p += sizeof(rdLen);
-  if (rdLen > (optLen - 11)) {
+  if (rdLen > (optLen - optRecordMinimumSize)) {
     return false;
   }
 
@@ -741,3 +737,31 @@ bool queryHasEDNS(const DNSQuestion& dq)
 
   return false;
 }
+
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
+{
+  uint16_t optStart;
+  size_t optLen = 0;
+  bool last = false;
+  const char * packet = reinterpret_cast<const char*>(dq.dh);
+  std::string packetStr(packet, dq.len);
+  int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+  if (res != 0) {
+    // no EDNS OPT RR
+    return false;
+  }
+
+  if (optLen < optRecordMinimumSize) {
+    return false;
+  }
+
+  if (optStart < dq.len && packetStr.at(optStart) != 0) {
+    // OPT RR Name != '.'
+    return false;
+  }
+
+  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
+  memcpy(&edns0, packet + optStart + 5, sizeof edns0);
+  return true;
+}
index 7c3739f443c2546c46101c4935169c9207c48496..767575723f059ca4320dfb89c88a055fb37fb300 100644 (file)
@@ -21,6 +21,9 @@
  */
 #pragma once
 
+// root label (1), type (2), class (2), ttl (4) + rdlen (2)
+static const size_t optRecordMinimumSize = 11;
+
 extern size_t g_EdnsUDPPayloadSize;
 extern uint16_t g_PayloadSizeSelfGenAnswers;
 
@@ -42,3 +45,4 @@ bool parseEDNSOptions(DNSQuestion& dq);
 
 int getEDNSZ(const DNSQuestion& dq);
 bool queryHasEDNS(const DNSQuestion& dq);
+bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0);
index 55c7251e0aaca0866ec83c78e1ffc17cdf89c5bf..4f255b6e8c0766ec2f8338cfcac9b657af856d7b 100644 (file)
@@ -19,6 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+#include "config.h"
 #include "threadname.hh"
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "ednsoptions.hh"
 #include "fstrm_logger.hh"
 #include "remote_logger.hh"
-#include "boost/optional/optional_io.hpp"
+
+#include <boost/optional/optional_io.hpp>
+
+#ifdef HAVE_LIBCRYPTO
+#include "ipcipher.hh"
+#endif /* HAVE_LIBCRYPTO */
 
 class DropAction : public DNSAction
 {
@@ -771,7 +777,7 @@ private:
 class RemoteLogAction : public DNSAction, public boost::noncopyable
 {
 public:
-  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID)
+  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey)
   {
   }
   DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
@@ -786,6 +792,13 @@ public:
       message.setServerIdentity(d_serverID);
     }
 
+#if HAVE_LIBCRYPTO
+    if (!d_ipEncryptKey.empty())
+    {
+      message.setRequestor(encryptCA(*dq->remote, d_ipEncryptKey));
+    }
+#endif /* HAVE_LIBCRYPTO */
+
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
       (*d_alterFunc)(*dq, &message);
@@ -805,6 +818,7 @@ private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
   boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
+  std::string d_ipEncryptKey;
 };
 
 class SNMPTrapAction : public DNSAction
@@ -891,7 +905,7 @@ private:
 class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
 {
 public:
-  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_includeCNAME(includeCNAME)
+  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME)
   {
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
@@ -906,6 +920,13 @@ public:
       message.setServerIdentity(d_serverID);
     }
 
+#if HAVE_LIBCRYPTO
+    if (!d_ipEncryptKey.empty())
+    {
+      message.setRequestor(encryptCA(*dr->remote, d_ipEncryptKey));
+    }
+#endif /* HAVE_LIBCRYPTO */
+
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
       (*d_alterFunc)(*dr, &message);
@@ -925,6 +946,7 @@ private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
   boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
+  std::string d_ipEncryptKey;
   bool d_includeCNAME;
 };
 
@@ -1053,14 +1075,6 @@ void setupLuaActions()
       addAction(&g_rulactions, var, boost::get<std::shared_ptr<DNSAction> >(era), params);
     });
 
-  g_lua.writeFunction("addLuaAction", [](luadnsrule_t var, LuaAction::func_t func, boost::optional<luaruleparams_t> params) {
-      addAction(&g_rulactions, var, std::make_shared<LuaAction>(func), params);
-    });
-
-  g_lua.writeFunction("addLuaResponseAction", [](luadnsrule_t var, LuaResponseAction::func_t func, boost::optional<luaruleparams_t> params) {
-      addAction(&g_resprulactions, var, std::make_shared<LuaResponseAction>(func), params);
-    });
-
   g_lua.writeFunction("addResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) {
       if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
         throw std::runtime_error("addResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
@@ -1221,14 +1235,18 @@ void setupLuaActions()
       }
 
       std::string serverID;
+      std::string ipEncryptKey;
       if (vars) {
         if (vars->count("serverID")) {
           serverID = boost::get<std::string>((*vars)["serverID"]);
         }
+        if (vars->count("ipEncryptKey")) {
+          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
+        }
       }
 
 #ifdef HAVE_PROTOBUF
-      return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID));
+      return std::shared_ptr<DNSAction>(new RemoteLogAction(logger, alterFunc, serverID, ipEncryptKey));
 #else
       throw std::runtime_error("Protobuf support is required to use RemoteLogAction");
 #endif
@@ -1243,14 +1261,18 @@ void setupLuaActions()
       }
 
       std::string serverID;
+      std::string ipEncryptKey;
       if (vars) {
         if (vars->count("serverID")) {
           serverID = boost::get<std::string>((*vars)["serverID"]);
         }
+        if (vars->count("ipEncryptKey")) {
+          ipEncryptKey = boost::get<std::string>((*vars)["ipEncryptKey"]);
+        }
       }
 
 #ifdef HAVE_PROTOBUF
-      return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, includeCNAME ? *includeCNAME : false));
+      return std::shared_ptr<DNSResponseAction>(new RemoteLogResponseAction(logger, alterFunc, serverID, ipEncryptKey, includeCNAME ? *includeCNAME : false));
 #else
       throw std::runtime_error("Protobuf support is required to use RemoteLogResponseAction");
 #endif
index 79e01d2e13792acb0162eab0757f63caea9d4115..6a8d64156de8cde7eb8534ae88112b08bbfa9fc0 100644 (file)
@@ -23,6 +23,7 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 
+#include "config.h"
 #include "dnsdist.hh"
 #include "dnsdist-lua.hh"
 #include "dnsdist-protobuf.hh"
 #include "fstrm_logger.hh"
 #include "remote_logger.hh"
 
+#ifdef HAVE_LIBCRYPTO
+#include "ipcipher.hh"
+#endif /* HAVE_LIBCRYPTO */
+
 void setupLuaBindings(bool client)
 {
   g_lua.writeFunction("infolog", [](const string& arg) {
@@ -166,6 +171,19 @@ void setupLuaBindings(bool client)
   g_lua.registerFunction<ComboAddress(ComboAddress::*)()>("mapToIPv4", [](const ComboAddress& ca) { return ca.mapToIPv4(); });
   g_lua.registerFunction<bool(nmts_t::*)(const ComboAddress&)>("match", [](nmts_t& s, const ComboAddress& ca) { return s.match(ca); });
 
+#ifdef HAVE_LIBCRYPTO
+  g_lua.registerFunction<ComboAddress(ComboAddress::*)(const std::string& key)>("ipencrypt", [](const ComboAddress& ca, const std::string& key) {
+      return encryptCA(ca, key);
+    });
+  g_lua.registerFunction<ComboAddress(ComboAddress::*)(const std::string& key)>("ipdecrypt", [](const ComboAddress& ca, const std::string& key) {
+      return decryptCA(ca, key);
+    });
+
+  g_lua.writeFunction("makeIPCipherKey", [](const std::string& password) {
+      return makeIPCipherKey(password);
+    });
+#endif /* HAVE_LIBCRYPTO */
+
   /* DNSName */
   g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
   g_lua.registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });
@@ -659,10 +677,10 @@ void setupLuaBindings(bool client)
   g_lua.registerFunction<size_t(EDNSOptionView::*)()>("count", [](const EDNSOptionView& option) {
       return option.values.size();
     });
-  g_lua.registerFunction<std::vector<std::pair<int, string>>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
-    std::vector<std::pair<int, string> > values;
+  g_lua.registerFunction<std::vector<string>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
+    std::vector<string> values;
     for (const auto& value : option.values) {
-      values.push_back(std::make_pair(values.size(), std::string(value.content, value.size)));
+      values.push_back(std::string(value.content, value.size));
     }
     return values;
   });
index c64bf563f1bf4c5dab3ccace2b30b09d7acc1e2f..25d9a187d3acd45da71d1dea8e67a797e44bfe26 100644 (file)
@@ -552,10 +552,38 @@ void setupLuaInspection()
 
   g_lua.writeFunction("showTCPStats", [] {
       setLuaNoSideEffect();
+      ostringstream ret;
       boost::format fmt("%-10d %-10d %-10d %-10d\n");
-      g_outputBuffer += (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued").str();
-      g_outputBuffer += (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections).str();
-      g_outputBuffer += "Query distribution mode is: " + std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") + "\n";
+      ret << (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued") << endl;
+      ret << (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections) << endl;
+      ret <<endl;
+
+      ret << "Query distribution mode is: " << std::string(g_useTCPSinglePipe ? "single queue" : "per-thread queues") << endl;
+      ret << endl;
+
+      ret << "Frontends:" << endl;
+      fmt = boost::format("%-3d %-20.20s %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f");
+      ret << (fmt % "#" % "Address" % "Connnections" % "Died reading query" % "Died sending response" % "Gave up" % "Client timeouts" % "Downstream timeouts" % "Avg queries/conn" % "Avg duration") << endl;
+
+      size_t counter = 0;
+      for(const auto& f : g_frontends) {
+        ret << (fmt % counter % f->local.toStringWithPort() % f->tcpCurrentConnections % f->tcpDiedReadingQuery % f->tcpDiedSendingResponse % f->tcpGaveUp % f->tcpClientTimeouts % f->tcpDownstreamTimeouts % f->tcpAvgQueriesPerConnection % f->tcpAvgConnectionDuration) << endl;
+        ++counter;
+      }
+      ret << endl;
+
+      ret << "Backends:" << endl;
+      fmt = boost::format("%-3d %-20.20s %-20.20s %-20d %-20d %-25d %-20d %-20d %-20d %-20f %-20f");
+      ret << (fmt % "#" % "Name" % "Address" % "Connections" % "Died sending query" % "Died reading response" % "Gave up" % "Read timeouts" % "Write timeouts" % "Avg queries/conn" % "Avg duration") << endl;
+
+      auto states = g_dstates.getLocal();
+      counter = 0;
+      for(const auto& s : *states) {
+        ret << (fmt % counter % s->name % s->remote.toStringWithPort() % s->tcpCurrentConnections % s->tcpDiedSendingQuery % s->tcpDiedReadingResponse % s->tcpGaveUp % s->tcpReadTimeouts % s->tcpWriteTimeouts % s->tcpAvgQueriesPerConnection % s->tcpAvgConnectionDuration) << endl;
+        ++counter;
+      }
+
+      g_outputBuffer=ret.str();
     });
 
   g_lua.writeFunction("dumpStats", [] {
@@ -660,6 +688,16 @@ void setupLuaInspection()
         group->setResponseByteRate(rate, warningRate ? *warningRate : 0, seconds, reason, blockDuration, action ? *action : DNSAction::Action::None);
       }
     });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, DynBlockRulesGroup::smtVisitor_t)>("setSuffixMatchRule", [](std::shared_ptr<DynBlockRulesGroup>& group, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, DynBlockRulesGroup::smtVisitor_t visitor) {
+      if (group) {
+        group->setSuffixMatchRule(seconds, reason, blockDuration, action ? *action : DNSAction::Action::None, visitor);
+      }
+    });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, dnsdist_ffi_stat_node_visitor_t)>("setSuffixMatchRuleFFI", [](std::shared_ptr<DynBlockRulesGroup>& group, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, dnsdist_ffi_stat_node_visitor_t visitor) {
+      if (group) {
+        group->setSuffixMatchRuleFFI(seconds, reason, blockDuration, action ? *action : DNSAction::Action::None, visitor);
+      }
+    });
   g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(uint8_t, unsigned int, unsigned int, const std::string&, unsigned int, boost::optional<DNSAction::Action>, boost::optional<unsigned int>)>("setRCodeRate", [](std::shared_ptr<DynBlockRulesGroup>& group, uint8_t rcode, unsigned int rate, unsigned int seconds, const std::string& reason, unsigned int blockDuration, boost::optional<DNSAction::Action> action, boost::optional<unsigned int> warningRate) {
       if (group) {
         group->setRCodeRate(rcode, rate, warningRate ? *warningRate : 0, seconds, reason, blockDuration, action ? *action : DNSAction::Action::None);
@@ -690,6 +728,16 @@ void setupLuaInspection()
         group->includeRange(Netmask(*boost::get<std::string>(&ranges)));
       }
     });
+  g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(boost::variant<std::string, std::vector<std::pair<int, std::string>>>)>("excludeDomains", [](std::shared_ptr<DynBlockRulesGroup>& group, boost::variant<std::string, std::vector<std::pair<int, std::string>>> domains) {
+      if (domains.type() == typeid(std::vector<std::pair<int, std::string>>)) {
+        for (const auto& range : *boost::get<std::vector<std::pair<int, std::string>>>(&domains)) {
+          group->excludeDomain(DNSName(range.second));
+        }
+      }
+      else {
+        group->excludeDomain(DNSName(*boost::get<std::string>(&domains)));
+      }
+    });
   g_lua.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)()>("apply", [](std::shared_ptr<DynBlockRulesGroup>& group) {
     group->apply();
   });
index d4d199436e1e78634e290a21d63fc4b26b9544bf..9cf840093155c51915a6904ab8d83c8cbd881bcf 100644 (file)
@@ -22,6 +22,8 @@
 #include "dnsdist.hh"
 #include "ednsoptions.hh"
 
+#undef BADSIG  // signal.h SIG_ERR
+
 void setupLuaVars()
 {
   g_lua.writeVariable("DNSAction", std::unordered_map<string,int>{
index c7e03da7b3bdafb2f965ba23744860a83e0a2687..69e5c571f67178cb8e6242c0d19a180e517382b9 100644 (file)
@@ -1268,13 +1268,13 @@ void setupLuaConfig(bool client)
       setLuaNoSideEffect();
       try {
         ostringstream ret;
-        boost::format fmt("%1$-3d %2$-20.20s %|25t|%3$-8.8s %|35t|%4%" );
+        boost::format fmt("%1$-3d %2$-20.20s %|35t|%3$-20.20s %|57t|%4%" );
         //             1    2           3            4
         ret << (fmt % "#" % "Address" % "Protocol" % "Queries" ) << endl;
 
         size_t counter = 0;
         for (const auto& front : g_frontends) {
-          ret << (fmt % counter % front->local.toStringWithPort() % (front->udpFD != -1 ? "UDP" : "TCP") % front->queries) << endl;
+          ret << (fmt % counter % front->local.toStringWithPort() % front->getType() % front->queries) << endl;
           counter++;
         }
         g_outputBuffer=ret.str();
@@ -1609,6 +1609,15 @@ void setupLuaConfig(bool client)
       g_secPollInterval = newInterval;
   });
 
+  g_lua.writeFunction("setSyslogFacility", [](int facility) {
+    setLuaSideEffect();
+    if (g_configurationDone) {
+      g_outputBuffer="setSyslogFacility cannot be used at runtime!\n";
+      return;
+    }
+    setSyslogFacility(facility);
+  });
+
   g_lua.writeFunction("addTLSLocal", [client](const std::string& addr, boost::variant<std::string, std::vector<std::pair<int,std::string>>> certFiles, boost::variant<std::string, std::vector<std::pair<int,std::string>>> keyFiles, boost::optional<localbind_t> vars) {
         if (client)
           return;
index 777865ff162b5f130234964f0f21f6beadc58984..878f056a02ec68844874abe230a466a4ab65eb43 100644 (file)
@@ -35,6 +35,8 @@
 #include <atomic>
 #include <netinet/tcp.h>
 
+#include "sstuff.hh"
+
 using std::thread;
 using std::atomic;
 
@@ -53,49 +55,168 @@ using std::atomic;
    Let's start naively.
 */
 
-static int setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures)
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+static const size_t g_maxCachedConnectionsPerDownstream = 20;
+uint64_t g_maxTCPQueuedConnections{1000};
+size_t g_maxTCPQueriesPerConn{0};
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+uint16_t g_downstreamTCPCleanupInterval{60};
+bool g_useTCPSinglePipe{false};
+
+static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
 {
+  std::unique_ptr<Socket> result;
+
   do {
     vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
-    int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+    result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
     try {
       if (!IsAnyAddress(ds->sourceAddr)) {
-        SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
+        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef IP_BIND_ADDRESS_NO_PORT
         if (ds->ipBindAddrNoPort) {
-          SSetsockopt(sock, SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
         }
 #endif
-        SBind(sock, ds->sourceAddr);
+        result->bind(ds->sourceAddr, false);
       }
-      setNonBlocking(sock);
+      result->setNonBlocking();
 #ifdef MSG_FASTOPEN
       if (!ds->tcpFastOpen) {
-        SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
       }
 #else
-      SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
 #endif /* MSG_FASTOPEN */
-      return sock;
+      return result;
     }
     catch(const std::runtime_error& e) {
-      /* don't leak our file descriptor if SConnect() (for example) throws */
+      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
       downstreamFailures++;
-      close(sock);
       if (downstreamFailures > ds->retries) {
         throw;
       }
     }
   } while(downstreamFailures <= ds->retries);
 
-  return -1;
+  return nullptr;
+}
+
+class TCPConnectionToBackend
+{
+public:
+  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now)
+  {
+    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
+    ++d_ds->tcpCurrentConnections;
+  }
+
+  ~TCPConnectionToBackend()
+  {
+    if (d_ds && d_socket) {
+      --d_ds->tcpCurrentConnections;
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      auto diff = now - d_connectionStartTime;
+      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
+    }
+  }
+
+  int getHandle() const
+  {
+    if (!d_socket) {
+      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
+    }
+
+    return d_socket->getHandle();
+  }
+
+  const ComboAddress& getRemote() const
+  {
+    return d_ds->remote;
+  }
+
+  bool isFresh() const
+  {
+    return d_fresh;
+  }
+
+  void incQueries()
+  {
+    ++d_queries;
+  }
+
+  void setReused()
+  {
+    d_fresh = false;
+  }
+
+private:
+  std::unique_ptr<Socket> d_socket{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  struct timeval d_connectionStartTime;
+  uint64_t d_queries{0};
+  bool d_fresh{true};
+};
+
+static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+
+static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
+{
+  std::unique_ptr<TCPConnectionToBackend> result;
+
+  const auto& it = t_downstreamConnections.find(ds->remote);
+  if (it != t_downstreamConnections.end()) {
+    auto& list = it->second;
+    if (!list.empty()) {
+      result = std::move(list.front());
+      list.pop_front();
+      result->setReused();
+      return result;
+    }
+  }
+
+  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
+}
+
+static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
+{
+  if (conn == nullptr) {
+    return;
+  }
+
+  const auto& remote = conn->getRemote();
+  const auto& it = t_downstreamConnections.find(remote);
+  if (it != t_downstreamConnections.end()) {
+    auto& list = it->second;
+    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
+      /* too many connections queued already */
+      conn.reset();
+      return;
+    }
+    list.push_back(std::move(conn));
+  }
+  else {
+    t_downstreamConnections[remote].push_back(std::move(conn));
+  }
 }
 
 struct ConnectionInfo
 {
-  ConnectionInfo(): cs(nullptr), fd(-1)
+  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
   {
   }
+  ConnectionInfo(ConnectionInfo&& rhs)
+  {
+    remote = rhs.remote;
+    cs = rhs.cs;
+    rhs.cs = nullptr;
+    fd = rhs.fd;
+    rhs.fd = -1;
+  }
 
   ConnectionInfo(const ConnectionInfo& rhs) = delete;
   ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
@@ -116,6 +237,9 @@ struct ConnectionInfo
       close(fd);
       fd = -1;
     }
+    if (cs) {
+      --cs->tcpCurrentConnections;
+    }
   }
 
   ComboAddress remote;
@@ -123,15 +247,6 @@ struct ConnectionInfo
   int fd{-1};
 };
 
-uint64_t g_maxTCPQueuedConnections{1000};
-size_t g_maxTCPQueriesPerConn{0};
-size_t g_maxTCPConnectionDuration{0};
-size_t g_maxTCPConnectionsPerClient{0};
-static std::mutex tcpClientsCountMutex;
-static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-bool g_useTCPSinglePipe{false};
-std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
-
 void tcpClientThread(int pipefd);
 
 static void decrementTCPClientCount(const ComboAddress& client)
@@ -161,6 +276,13 @@ void TCPClientCollection::addTCPClientThread()
       return;
     }
 
+    if (!setNonBlocking(pipefds[0])) {
+      close(pipefds[0]);
+      close(pipefds[1]);
+      errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
+      return;
+    }
+
     if (!setNonBlocking(pipefds[1])) {
       close(pipefds[0]);
       close(pipefds[1]);
@@ -201,544 +323,914 @@ void TCPClientCollection::addTCPClientThread()
   ++d_numthreads;
 }
 
-static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
-try
+static void cleanupClosedTCPConnections()
 {
-  uint16_t raw;
-  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
-}
-
-static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len)
-try
-{
-  uint16_t raw;
-  size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
-}
+  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
+    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
+      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
+        ++connIt;
+      }
+      else {
+        connIt = dsIt->second.erase(connIt);
+      }
+    }
 
-static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
-{
-  if (maxConnectionDuration) {
-    time_t curtime = time(nullptr);
-    unsigned int elapsed = 0;
-    if (curtime > start) { // To prevent issues when time goes backward
-      elapsed = curtime - start;
+    if (!dsIt->second.empty()) {
+      ++dsIt;
     }
-    if (elapsed >= maxConnectionDuration) {
-      return true;
+    else {
+      dsIt = t_downstreamConnections.erase(dsIt);
     }
-    remainingTime = maxConnectionDuration - elapsed;
   }
-  return false;
 }
 
-void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
+/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+   Updates pos everytime a successful read occurs,
+   throws an std::runtime_error in case of IO error,
+   return Done when toRead bytes have been read, needRead or needWrite if the IO operation
+   would block.
+*/
+// XXX could probably be implemented as a TCPIOHandler
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
 {
-  for(auto it = sockets.begin(); it != sockets.end(); ) {
-    if (isTCPSocketUsable(it->second)) {
-      ++it;
+  if (buffer.size() < (pos + toRead)) {
+    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
+  }
+
+  size_t got = 0;
+  do {
+    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+    if (res == 0) {
+      throw runtime_error("EOF while reading message");
     }
-    else {
-      close(it->second);
-      it = sockets.erase(it);
+    if (res < 0) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        return IOState::NeedRead;
+      }
+      else {
+        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+      }
     }
+
+    pos += static_cast<size_t>(res);
+    got += static_cast<size_t>(res);
   }
+  while (got < toRead);
+
+  return IOState::Done;
 }
 
-std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
 
-void tcpClientThread(int pipefd)
+class TCPClientThreadData
 {
-  /* we get launched with a pipe on which we receive file descriptors from clients that we own
-     from that point on */
-
-  setThreadName("dnsdist/tcpClie");
+public:
+  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  {
+  }
 
-  bool outstanding = false;
-  time_t lastTCPCleanup = time(nullptr);
-  
   LocalHolders holders;
-  auto localRespRulactions = g_resprulactions.getLocal();
-#ifdef HAVE_DNSCRYPT
-  /* when the answer is encrypted in place, we need to get a copy
-     of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
-#endif
+  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
+  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+};
 
-  map<ComboAddress,int> sockets;
-  for(;;) {
-    ConnectionInfo* citmp, ci;
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+class IncomingTCPConnectionState
+{
+public:
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
+  {
+    d_ids.origDest.reset();
+    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+    socklen_t socklen = d_ids.origDest.getSocklen();
+    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
+      d_ids.origDest = d_ci.cs->local;
+    }
+  }
+
+  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+  ~IncomingTCPConnectionState()
+  {
+    decrementTCPClientCount(d_ci.remote);
+    if (d_ci.cs != nullptr) {
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      auto diff = now - d_connectionStartTime;
+      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
+    }
+
+    if (d_ds != nullptr) {
+      if (d_outstanding) {
+        --d_ds->outstanding;
+        d_outstanding = false;
+      }
+
+      if (d_downstreamConnection) {
+        try {
+          if (d_lastIOState == IOState::NeedRead) {
+            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
+            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
+          }
+          else if (d_lastIOState == IOState::NeedWrite) {
+            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
+            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
+          }
+        }
+        catch(const FDMultiplexerException& e) {
+          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+        }
+      }
+    }
 
     try {
-      readn2(pipefd, &citmp, sizeof(citmp));
+      if (d_lastIOState == IOState::NeedRead) {
+        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeReadFD(d_ci.fd);
+      }
+      else if (d_lastIOState == IOState::NeedWrite) {
+        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeWriteFD(d_ci.fd);
+      }
     }
-    catch(const std::runtime_error& e) {
-      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
+    catch(const FDMultiplexerException& e) {
+      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
     }
+  }
 
-    g_tcpclientthreads->decrementQueuedCount();
-    ci=std::move(*citmp);
-    delete citmp;
+  void resetForNewQuery()
+  {
+    d_buffer.resize(sizeof(uint16_t));
+    d_currentPos = 0;
+    d_querySize = 0;
+    d_responseSize = 0;
+    d_downstreamFailures = 0;
+    d_state = State::readingQuerySize;
+    d_lastIOState = IOState::Done;
+  }
 
-    uint16_t qlen, rlen;
-    vector<uint8_t> rewrittenResponse;
-    shared_ptr<DownstreamState> ds;
-    ComboAddress dest;
-    dest.reset();
-    dest.sin4.sin_family = ci.remote.sin4.sin_family;
-    socklen_t len = dest.getSocklen();
-    size_t queriesCount = 0;
-    time_t connectionStartTime = time(NULL);
-    std::vector<char> queryBuffer;
-    std::vector<char> answerBuffer;
+  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
+      return boost::none;
+    }
 
-    if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
-      dest = ci.cs->local;
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+        return now;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
+        now.tv_sec += remaining;
+        return now;
+      }
     }
 
-    try {
-      TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime);
+    now.tv_sec += g_tcpRecvTimeout;
+    return now;
+  }
 
-      for(;;) {
-        unsigned int remainingTime = 0;
-        ds = nullptr;
-        outstanding = false;
+  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() without any backend selected");
+    }
+    if (d_ds->tcpRecvTimeout == 0) {
+      return boost::none;
+    }
 
-        if(!getNonBlockingMsgLenFromClient(handler, &qlen)) {
-          break;
-        }
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpRecvTimeout;
 
-        queriesCount++;
+    return res;
+  }
 
-        if (qlen < sizeof(dnsheader)) {
-          ++g_stats.nonCompliantQueries;
-          break;
-        }
+  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+      return boost::none;
+    }
 
-        ci.cs->queries++;
-        ++g_stats.queries;
+    struct timeval res = now;
 
-        if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
-          break;
-        }
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+        return res;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+        res.tv_sec += remaining;
+        return res;
+      }
+    }
 
-        if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
-          break;
-        }
+    res.tv_sec += g_tcpSendTimeout;
+    return res;
+  }
 
-        bool ednsAdded = false;
-        bool ecsAdded = false;
-        /* allocate a bit more memory to be able to spoof the content,
-           or to add ECS without allocating a new buffer */
-        queryBuffer.resize(qlen + 512);
-
-        char* query = &queryBuffer[0];
-        handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
-
-        /* we need this one to be accurate ("real") for the protobuf message */
-       struct timespec queryRealTime;
-       struct timespec now;
-       gettime(&now);
-       gettime(&queryRealTime, true);
-
-#ifdef HAVE_DNSCRYPT
-        std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-        if (ci.cs->dnscryptCtx) {
-          dnsCryptQuery = std::make_shared<DNSCryptQuery>(ci.cs->dnscryptCtx);
-          uint16_t decryptedQueryLen = 0;
-          vector<uint8_t> response;
-          bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response);
-
-          if (!decrypted) {
-            if (response.size() > 0) {
-              handler.writeSizeAndMsg(response.data(), response.size(), g_tcpSendTimeout);
-            }
-            break;
-          }
-          qlen = decryptedQueryLen;
-        }
-#endif
-        struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() called without any backend selected");
+    }
+    if (d_ds->tcpSendTimeout == 0) {
+      return boost::none;
+    }
 
-        if (!checkQueryHeaders(dh)) {
-          goto drop;
-        }
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpSendTimeout;
 
-       string poolname;
-       int delayMsec=0;
+    return res;
+  }
 
-       const uint16_t* flags = getFlagsFromDNSHeader(dh);
-       uint16_t origFlags = *flags;
-       uint16_t qtype, qclass;
-       unsigned int consumed = 0;
-       DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-       DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
+  {
+    if (maxConnectionDuration) {
+      time_t curtime = now.tv_sec;
+      unsigned int elapsed = 0;
+      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
+        elapsed = curtime - d_connectionStartTime.tv_sec;
+      }
+      if (elapsed >= maxConnectionDuration) {
+        return true;
+      }
+      d_remainingTime = maxConnectionDuration - elapsed;
+    }
 
-       if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
-         goto drop;
-       }
+    return false;
+  }
 
-       if(dq.dh->qr) { // something turned it into a response
-          fixUpQueryTurnedResponse(dq, origFlags);
+  void dump() const
+  {
+    static std::mutex s_mutex;
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+
+    {
+      std::lock_guard<std::mutex> lock(s_mutex);
+      fprintf(stderr, "State is %p\n", this);
+      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
+      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
+      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
+      if (d_state > State::doingHandshake) {
+        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuery) {
+        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
+      }
+      if (d_state > State::sendingQueryToBackend) {
+        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
+      }
+      if (d_state > State::readingResponseFromBackend) {
+        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
+      }
+    }
+  }
 
-          DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-          dr.uniqueId = dq.uniqueId;
-#endif
-          dr.qTag = dq.qTag;
+  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
+
+  std::vector<uint8_t> d_buffer;
+  std::vector<uint8_t> d_responseBuffer;
+  TCPClientThreadData& d_threadData;
+  IDState d_ids;
+  ConnectionInfo d_ci;
+  TCPIOHandler d_handler;
+  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  struct timeval d_connectionStartTime;
+  struct timeval d_handshakeDoneTime;
+  struct timeval d_firstQuerySizeReadTime;
+  struct timeval d_querySizeReadTime;
+  struct timeval d_queryReadTime;
+  struct timeval d_querySentTime;
+  struct timeval d_responseReadTime;
+  size_t d_currentPos{0};
+  size_t d_queriesCount{0};
+  unsigned int d_remainingTime{0};
+  uint16_t d_querySize{0};
+  uint16_t d_responseSize{0};
+  uint16_t d_downstreamFailures{0};
+  State d_state{State::doingHandshake};
+  IOState d_lastIOState{IOState::Done};
+  bool d_readingFirstQuery{true};
+  bool d_outstanding{false};
+  bool d_firstResponsePacket{true};
+  bool d_isXFR{false};
+  bool d_xfrStarted{false};
+};
 
-          if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-            goto drop;
-          }
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
 
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-            goto drop;
-          }
-#endif
-          handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
-          ++g_stats.selfAnswered;
-          continue;
-        }
+static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+
+  if (state->d_isXFR && state->d_downstreamConnection) {
+    /* we need to resume reading from the backend! */
+    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+    state->d_currentPos = 0;
+    handleDownstreamIO(state, now);
+    return;
+  }
 
-        std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-        std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+    return;
+  }
 
-        auto policy = *(holders.policy);
-        if (serverPool->policy != nullptr) {
-          policy = *(serverPool->policy);
-        }
-        auto servers = serverPool->getServers();
-        if (policy.isLua) {
-          std::lock_guard<std::mutex> lock(g_luamutex);
-          ds = policy.policy(servers, &dq);
-        }
-        else {
-          ds = policy.policy(servers, &dq);
-        }
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    return;
+  }
 
-        uint32_t cacheKeyNoECS = 0;
-        uint32_t cacheKey = 0;
-        boost::optional<Netmask> subnet;
-        char cachedResponse[4096];
-        uint16_t cachedResponseSize = sizeof cachedResponse;
-        uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
-        bool useZeroScope = false;
-
-        bool dnssecOK = false;
-        if (packetCache && !dq.skipCache) {
-          dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
-        }
+  state->resetForNewQuery();
 
-        if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
-          // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-          if (packetCache && !dq.skipCache && (!ds || !ds->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-            if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-              DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-              dr.uniqueId = dq.uniqueId;
-#endif
-              dr.qTag = dq.qTag;
+  handleIO(state, now);
+}
 
-              if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-                goto drop;
-              }
+static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
 
-#ifdef HAVE_DNSCRYPT
-              if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-                goto drop;
-              }
-#endif
-              handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-              g_stats.cacheHits++;
-              continue;
-            }
-
-            if (!subnet) {
-              /* there was no existing ECS on the query, enable the zero-scope feature */
-              useZeroScope = true;
-            }
-          }
+  state->d_currentPos = 0;
 
-          if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-            vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
-            goto drop;
-          }
-        }
+  handleIO(state, now);
+}
 
-        if (packetCache && !dq.skipCache) {
-          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
+static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_responseSize < sizeof(dnsheader)) {
+    return;
+  }
 
-            if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
+  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
+  unsigned int consumed;
+  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+    return;
+  }
+  state->d_firstResponsePacket = false;
 
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
-            ++g_stats.cacheHits;
-            continue;
-          }
-          ++g_stats.cacheMisses;
-        }
+  if (state->d_outstanding) {
+    --state->d_ds->outstanding;
+    state->d_outstanding = false;
+  }
 
-        if(!ds) {
-          ++g_stats.noPolicy;
+  auto dh = reinterpret_cast<struct dnsheader*>(response);
+  uint16_t addRoom = 0;
+  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+  if (dr.dnsCryptQuery) {
+    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+  }
 
-          if (g_servFailOnNoPolicy) {
-            restoreFlags(dh, origFlags);
-            dq.dh->rcode = RCode::ServFail;
-            dq.dh->qr = true;
+  dnsheader cleartextDH;
+  memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
 
-            DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-            dr.uniqueId = dq.uniqueId;
-#endif
-            dr.qTag = dq.qTag;
+  std::vector<uint8_t> rewrittenResponse;
+  size_t responseSize = state->d_responseBuffer.size();
+  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+    return;
+  }
 
-            if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-              goto drop;
-            }
+  if (!rewrittenResponse.empty()) {
+    /* responseSize has been updated as well but we don't really care since it will match
+       the capacity of rewrittenResponse anyway */
+    state->d_responseBuffer = std::move(rewrittenResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+  } else {
+    /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */
+    state->d_responseBuffer.resize(state->d_responseSize);
+  }
 
-#ifdef HAVE_DNSCRYPT
-            if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
-              goto drop;
-            }
-#endif
-            handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
+  if (state->d_isXFR && !state->d_xfrStarted) {
+    /* don't bother parsing the content of the response for now */
+    state->d_xfrStarted = true;
+  }
 
-            // no response-only statistics counter to update.
-            continue;
-          }
+  sendResponse(state, now);
 
-          break;
-        }
+  ++g_stats.responses;
+  struct timespec answertime;
+  gettime(&answertime);
+  double udiff = state->d_ids.sentTime.udiff();
+  g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
+}
 
-        if (dq.addXPF && ds->xpfRRCode != 0) {
-          addXPF(dq, ds->xpfRRCode, g_preserveTrailingData);
-        }
+static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  auto ds = state->d_ds;
+  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
+  state->d_currentPos = 0;
+  state->d_firstResponsePacket = true;
+  state->d_downstreamConnection.reset();
+
+  if (state->d_xfrStarted) {
+    /* sorry, but we are not going to resume a XFR if we have already sent some packets
+       to the client */
+    return;
+  }
 
-       int dsock = -1;
-       uint16_t downstreamFailures=0;
-#ifdef MSG_FASTOPEN
-       bool freshConn = true;
-#endif /* MSG_FASTOPEN */
-       if(sockets.count(ds->remote) == 0) {
-         dsock=setupTCPDownstream(ds, downstreamFailures);
-         sockets[ds->remote]=dsock;
-       }
-       else {
-         dsock=sockets[ds->remote];
-#ifdef MSG_FASTOPEN
-         freshConn = false;
-#endif /* MSG_FASTOPEN */
-        }
+  while (state->d_downstreamFailures < state->d_ds->retries)
+  {
+    state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
 
-        ds->queries++;
-        ds->outstanding++;
-        outstanding = true;
+    if (!state->d_downstreamConnection) {
+      ++ds->tcpGaveUp;
+      ++state->d_ci.cs->tcpGaveUp;
+      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+      return;
+    }
 
-      retry:; 
-        if (dsock < 0) {
-          sockets.erase(ds->remote);
-          break;
-        }
+    handleDownstreamIO(state, now);
+    return;
+  }
 
-        if (ds->retries > 0 && downstreamFailures > ds->retries) {
-          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures);
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          break;
-        }
+  ++ds->tcpGaveUp;
+  ++state->d_ci.cs->tcpGaveUp;
+  vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+}
 
-        try {
-          int socketFlags = 0;
-#ifdef MSG_FASTOPEN
-          if (ds->tcpFastOpen && freshConn) {
-            socketFlags |= MSG_FASTOPEN;
-          }
-#endif /* MSG_FASTOPEN */
-          sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags);
-        }
-        catch(const runtime_error& e) {
-          vinfolog("Downstream connection to %s died on us (%s), getting a new one!", ds->getName(), e.what());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
-#ifdef MSG_FASTOPEN
-          freshConn=true;
-#endif /* MSG_FASTOPEN */
-          goto retry;
-        }
+static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_querySize < sizeof(dnsheader)) {
+    ++g_stats.nonCompliantQueries;
+    return;
+  }
 
-        bool xfrStarted = false;
-        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
-        if (isXFR) {
-          dq.skipCache = true;
-        }
-        bool firstPacket=true;
-      getpacket:;
-
-        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
-         vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
+  state->d_readingFirstQuery = false;
+  ++state->d_queriesCount;
+  ++state->d_ci.cs->queries;
+  ++g_stats.queries;
+
+  /* we need an accurate ("real") value for the response and
+     to store into the IDS, but not for insertion into the
+     rings for example */
+  struct timespec queryRealTime;
+  gettime(&queryRealTime, true);
+
+  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
+  if (dnsCryptResponse) {
+    state->d_responseBuffer = std::move(*dnsCryptResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state, now);
+    return;
+  }
+
+  const auto& dh = reinterpret_cast<dnsheader*>(query);
+  if (!checkQueryHeaders(dh)) {
+    return;
+  }
+
+  uint16_t qtype, qclass;
+  unsigned int consumed = 0;
+  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+  dq.dnsCryptQuery = std::move(dnsCryptQuery);
+
+  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
+  if (state->d_isXFR) {
+    dq.skipCache = true;
+  }
+
+  state->d_ds.reset();
+  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
+
+  if (result == ProcessQueryResult::Drop) {
+    return;
+  }
+
+  if (result == ProcessQueryResult::SendAnswer) {
+    state->d_buffer.resize(dq.len);
+    state->d_responseBuffer = std::move(state->d_buffer);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state, now);
+    return;
+  }
+
+  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
+    return;
+  }
+
+  state->d_buffer.resize(dq.len);
+  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
+
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
+  sendQueryToBackend(state, now);
+}
+
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
+{
+  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
+
+  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
+    state->d_threadData.mplexer->removeReadFD(fd);
+    //cerr<<__func__<<": remove read FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
+  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
+    state->d_threadData.mplexer->removeWriteFD(fd);
+    //cerr<<__func__<<": remove write FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
+
+  if (iostate == IOState::NeedRead) {
+    if (state->d_lastIOState == IOState::NeedRead) {
+      if (ttd) {
+        /* let's update the TTD ! */
+        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
+      }
+      return;
+    }
+
+    state->d_lastIOState = IOState::NeedRead;
+    //cerr<<__func__<<": add read FD "<<fd<<endl;
+    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::NeedWrite) {
+    if (state->d_lastIOState == IOState::NeedWrite) {
+      return;
+    }
+
+    state->d_lastIOState = IOState::NeedWrite;
+    //cerr<<__func__<<": add write FD "<<fd<<endl;
+    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::Done) {
+    state->d_lastIOState = IOState::Done;
+  }
+}
+
+static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  if (state->d_downstreamConnection == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+
+  int fd = state->d_downstreamConnection->getHandle();
+  IOState iostate = IOState::Done;
+  bool connectionDied = false;
+
+  try {
+    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+      int socketFlags = 0;
 #ifdef MSG_FASTOPEN
-          freshConn=true;
+      if (state->d_ds->tcpFastOpen && state->d_downstreamConnection->isFresh()) {
+        socketFlags |= MSG_FASTOPEN;
+      }
 #endif /* MSG_FASTOPEN */
-          if(xfrStarted) {
-            goto drop;
-          }
-          goto retry;
-        }
 
-        size_t responseSize = rlen;
-        uint16_t addRoom = 0;
-#ifdef HAVE_DNSCRYPT
-        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
-          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
-        }
-#endif
-        responseSize += addRoom;
-        answerBuffer.resize(responseSize);
-        char* response = answerBuffer.data();
-        readn2WithTimeout(dsock, response, rlen, ds->tcpRecvTimeout);
-        uint16_t responseLen = rlen;
-        if (outstanding) {
-          /* might be false for {A,I}XFR */
-          --ds->outstanding;
-          outstanding = false;
+      size_t sent = sendMsgWithTimeout(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, 0, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, 0, socketFlags);
+      if (sent == state->d_buffer.size()) {
+        /* request sent ! */
+        state->d_downstreamConnection->incQueries();
+        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+        state->d_currentPos = 0;
+        state->d_querySentTime = now;
+        iostate = IOState::NeedRead;
+        if (!state->d_isXFR) {
+          /* don't bother with the outstanding count for XFR queries */
+          ++state->d_ds->outstanding;
+          state->d_outstanding = true;
         }
+      }
+      else {
+        state->d_currentPos += sent;
+        iostate = IOState::NeedWrite;
+        /* disable fast open on partial write */
+        state->d_downstreamConnection->setReused();
+      }
+    }
 
-        if (rlen < sizeof(dnsheader)) {
-          break;
-        }
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
+      // then we need to allocate a new buffer (new because we might need to re-send the query if the
+      // backend dies on us
+      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+      // should very likely be a TCPIOHandler d_downstreamHandler
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
+        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
+        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
+        state->d_currentPos = 0;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
 
-        consumed = 0;
-        if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) {
-          break;
+        if (state->d_isXFR) {
+          /* Don't reuse the TCP connection after an {A,I}XFR */
+          /* but don't reset it either, we will need to read more messages */
         }
-        firstPacket=false;
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom, useZeroScope ? &zeroScope : nullptr)) {
-          break;
+        else {
+          releaseDownstreamConnection(std::move(state->d_downstreamConnection));
         }
+        fd = -1;
 
-        dh = (struct dnsheader*) response;
-        DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = dq.uniqueId;
-#endif
-        dr.qTag = dq.qTag;
+        state->d_responseReadTime = now;
+        handleResponse(state, now);
+        return;
+      }
+    }
 
-        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
-          break;
-        }
+    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
+    }
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
+    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+      ++state->d_ds->tcpDiedSendingQuery;
+    }
+    else {
+      ++state->d_ds->tcpDiedReadingResponse;
+    }
 
-       if (packetCache && !dq.skipCache) {
-          if (!useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          packetCache->insert(zeroScope ? cacheKeyNoECS : cacheKey, zeroScope ? boost::none : subnet, origFlags, dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
-       }
+    /* don't increase this counter when reusing connections */
+    if (state->d_downstreamConnection->isFresh()) {
+      ++state->d_downstreamFailures;
+    }
+    if (state->d_outstanding && state->d_ds != nullptr) {
+      --state->d_ds->outstanding;
+      state->d_outstanding = false;
+    }
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+    connectionDied = true;
+  }
+
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
+  }
+
+  if (connectionDied) {
+    sendQueryToBackend(state, now);
+  }
+}
+
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (state->d_downstreamConnection == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+  if (fd != state->d_downstreamConnection->getHandle()) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
+  }
 
-#ifdef HAVE_DNSCRYPT
-        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery, &dh, &dhCopy)) {
-          goto drop;
+  struct timeval now;
+  gettimeofday(&now, 0);
+  handleDownstreamIO(state, now);
+}
+
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+{
+  int fd = state->d_ci.fd;
+  IOState iostate = IOState::Done;
+
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+    return;
+  }
+
+  try {
+    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+      iostate = state->d_handler.tryHandshake();
+      if (iostate == IOState::Done) {
+        state->d_handshakeDoneTime = now;
+        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingQuery;
+        state->d_querySizeReadTime = now;
+        if (state->d_queriesCount == 0) {
+          state->d_firstQuerySizeReadTime = now;
         }
-#endif
-        if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
-          break;
+        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+        if (state->d_querySize < sizeof(dnsheader)) {
+          /* go away */
+          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+          return;
         }
 
-        if (isXFR) {
-          if (dh->rcode == 0 && dh->ancount != 0) {
-            if (xfrStarted == false) {
-              xfrStarted = true;
-              if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
-                goto getpacket;
-              }
-            }
-            else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
-              goto getpacket;
-            }
-          }
-          /* Don't reuse the TCP connection after an {A,I}XFR */
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-        }
+        /* allocate a bit more memory to be able to spoof the content,
+           or to add ECS without allocating a new buffer */
+        state->d_buffer.resize(state->d_querySize + 512);
+        state->d_currentPos = 0;
+      }
+    }
 
-        ++g_stats.responses;
-        struct timespec answertime;
-        gettime(&answertime);
-        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
-        g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote);
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+      if (iostate == IOState::Done) {
+        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+        handleQuery(state, now);
+        return;
+      }
+    }
 
-        rewrittenResponse.clear();
+    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
+      if (iostate == IOState::Done) {
+        handleResponseSent(state, now);
+        return;
       }
     }
-    catch(...) {}
 
-  drop:;
+    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
+    }
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
+        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      ++state->d_ci.cs->tcpDiedReadingQuery;
+    }
+    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      ++state->d_ci.cs->tcpDiedSendingResponse;
+    }
+
+    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+    }
+    else {
+      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+    }
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+  }
+
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+  }
+}
+
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (fd != state->d_ci.fd) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
+  }
+  struct timeval now;
+  gettimeofday(&now, 0);
+
+  handleIO(state, now);
+}
+
+static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+  auto threadData = boost::any_cast<TCPClientThreadData*>(param);
+
+  ConnectionInfo* citmp{nullptr};
+
+  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
+  if (got == 0) {
+    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+  else if (got == -1) {
+    if (errno == EAGAIN || errno == EINTR) {
+      return;
+    }
+    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + strerror(errno));
+  }
+  else if (got != sizeof(citmp)) {
+    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  }
+
+  try {
+    g_tcpclientthreads->decrementQueuedCount();
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
+    delete citmp;
+    citmp = nullptr;
+
+    /* let's update the remaining time */
+    state->d_remainingTime = g_maxTCPConnectionDuration;
+
+    handleIO(state, now);
+  }
+  catch(...) {
+    delete citmp;
+    citmp = nullptr;
+    throw;
+  }
+}
 
-    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
+void tcpClientThread(int pipefd)
+{
+  /* we get launched with a pipe on which we receive file descriptors from clients that we own
+     from that point on */
+
+  setThreadName("dnsdist/tcpClie");
 
-    if (ds && outstanding) {
-      outstanding = false;
-      --ds->outstanding;
+  TCPClientThreadData data;
+
+  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
+  struct timeval now;
+  gettimeofday(&now, 0);
+  time_t lastTCPCleanup = now.tv_sec;
+  time_t lastTimeoutScan = now.tv_sec;
+
+  for (;;) {
+    data.mplexer->run(&now);
+
+    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
+      cleanupClosedTCPConnections();
+      lastTCPCleanup = now.tv_sec;
     }
-    decrementTCPClientCount(ci.remote);
 
-    if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
-      cleanupClosedTCPConnections(sockets);
-      lastTCPCleanup = time(nullptr);
+    if (now.tv_sec > lastTimeoutScan) {
+      lastTimeoutScan = now.tv_sec;
+      auto expiredReadConns = data.mplexer->getTimeouts(now, false);
+      for(const auto& conn : expiredReadConns) {
+        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+        if (conn.first == state->d_ci.fd) {
+          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+          ++state->d_ci.cs->tcpClientTimeouts;
+        }
+        else if (state->d_ds) {
+          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
+          ++state->d_ci.cs->tcpDownstreamTimeouts;
+          ++state->d_ds->tcpReadTimeouts;
+        }
+        data.mplexer->removeReadFD(conn.first);
+        state->d_lastIOState = IOState::Done;
+      }
+
+      auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
+      for(const auto& conn : expiredWriteConns) {
+        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+        if (conn.first == state->d_ci.fd) {
+          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+          ++state->d_ci.cs->tcpClientTimeouts;
+        }
+        else if (state->d_ds) {
+          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
+          ++state->d_ci.cs->tcpDownstreamTimeouts;
+          ++state->d_ds->tcpWriteTimeouts;
+        }
+        data.mplexer->removeWriteFD(conn.first);
+        state->d_lastIOState = IOState::Done;
+      }
     }
   }
 }
 
-/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and 
+/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
    they will hand off to worker threads & spawn more of them if required
 */
 void tcpAcceptorThread(void* p)
@@ -748,7 +1240,7 @@ void tcpAcceptorThread(void* p)
   bool tcpClientCountIncremented = false;
   ComboAddress remote;
   remote.sin4.sin_family = cs->local.sin4.sin_family;
-  
+
   g_tcpclientthreads->addTCPClientThread();
 
   auto acl = g_ACL.getLocal();
@@ -758,13 +1250,14 @@ void tcpAcceptorThread(void* p)
     tcpClientCountIncremented = false;
     try {
       socklen_t remlen = remote.getSocklen();
-      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo);
-      ci->cs = cs;
+      ci = std::unique_ptr<ConnectionInfo>(new ConnectionInfo(cs));
 #ifdef HAVE_ACCEPT4
-      ci->fd = accept4(cs->tcpFD, (struct sockaddr*)&remote, &remlen, SOCK_NONBLOCK);
+      ci->fd = accept4(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
 #else
-      ci->fd = accept(cs->tcpFD, (struct sockaddr*)&remote, &remlen);
+      ci->fd = accept(cs->tcpFD, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
 #endif
+      ++cs->tcpCurrentConnections;
+
       if(ci->fd < 0) {
         throw std::runtime_error((boost::format("accepting new connection on socket: %s") % strerror(errno)).str());
       }
@@ -821,7 +1314,7 @@ void tcpAcceptorThread(void* p)
         }
       }
     }
-    catch(std::exception& e) {
+    catch(const std::exception& e) {
       errlog("While reading a TCP question: %s", e.what());
       if(tcpClientCountIncremented) {
         decrementTCPClientCount(remote);
index 9cb51a622b17d5f1b37fb251402a9737284746f7..291e16661d7cb1b7af40c158ca8fe53dc7f42552 100644 (file)
@@ -446,20 +446,36 @@ static void connectionThread(int sock, ComboAddress remote)
         auto states = g_dstates.getLocal();
         const string statesbase = "dnsdist_server_";
 
-        output << "# HELP " << statesbase << "queries "     << "Amount of queries relayed to server"                               << "\n";
-        output << "# TYPE " << statesbase << "queries "     << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "drops "       << "Amount of queries not answered by server"                          << "\n";
-        output << "# TYPE " << statesbase << "drops "       << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "latency "     << "Server's latency when answering questions in miliseconds"          << "\n";
-        output << "# TYPE " << statesbase << "latency "     << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "senderrors "  << "Total number of OS snd errors while relaying queries"              << "\n";
-        output << "# TYPE " << statesbase << "senderrors "  << "counter"                                                           << "\n";
-        output << "# HELP " << statesbase << "outstanding " << "Current number of queries that are waiting for a backend response" << "\n";
-        output << "# TYPE " << statesbase << "outstanding " << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "order "       << "The order in which this server is picked"                          << "\n";
-        output << "# TYPE " << statesbase << "order "       << "gauge"                                                             << "\n";
-        output << "# HELP " << statesbase << "weight "      << "The weight within the order in which this server is picked"        << "\n";
-        output << "# TYPE " << statesbase << "weight "      << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "queries "                << "Amount of queries relayed to server"                               << "\n";
+        output << "# TYPE " << statesbase << "queries "                << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "drops "                  << "Amount of queries not answered by server"                          << "\n";
+        output << "# TYPE " << statesbase << "drops "                  << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "latency "                << "Server's latency when answering questions in miliseconds"          << "\n";
+        output << "# TYPE " << statesbase << "latency "                << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "senderrors "             << "Total number of OS snd errors while relaying queries"              << "\n";
+        output << "# TYPE " << statesbase << "senderrors "             << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "outstanding "            << "Current number of queries that are waiting for a backend response" << "\n";
+        output << "# TYPE " << statesbase << "outstanding "            << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "order "                  << "The order in which this server is picked"                          << "\n";
+        output << "# TYPE " << statesbase << "order "                  << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "weight "                 << "The weight within the order in which this server is picked"        << "\n";
+        output << "# TYPE " << statesbase << "weight "                 << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpdiedsendingquery "    << "The number of TCP I/O errors while sending the query"              << "\n";
+        output << "# TYPE " << statesbase << "tcpdiedsendingquery "    << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpdiedreadingresponse " << "The number of TCP I/O errors while reading the response"           << "\n";
+        output << "# TYPE " << statesbase << "tcpdiedreadingresponse " << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpgaveup "              << "The number of TCP connections failing after too many attempts"     << "\n";
+        output << "# TYPE " << statesbase << "tcpgaveup "              << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpreadtimeouts "        << "The number of TCP read timeouts"                                   << "\n";
+        output << "# TYPE " << statesbase << "tcpreadtimeouts "        << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpwritetimeouts "       << "The number of TCP write timeouts"                                  << "\n";
+        output << "# TYPE " << statesbase << "tcpwritetimeouts "       << "counter"                                                           << "\n";
+        output << "# HELP " << statesbase << "tcpcurrentconnections "  << "The number of current TCP connections"                             << "\n";
+        output << "# TYPE " << statesbase << "tcpcurrentconnections "  << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpavgqueriesperconn "   << "The average number of queries per TCP connection"                  << "\n";
+        output << "# TYPE " << statesbase << "tcpavgqueriesperconn "   << "gauge"                                                             << "\n";
+        output << "# HELP " << statesbase << "tcpavgconnduration "     << "The average duration of a TCP connection (ms)"                     << "\n";
+        output << "# TYPE " << statesbase << "tcpavgconnduration "     << "gauge"                                                             << "\n";
 
         for (const auto& state : *states) {
           string serverName;
@@ -474,13 +490,21 @@ static void connectionThread(int sock, ComboAddress remote)
           const std::string label = boost::str(boost::format("{server=\"%1%\",address=\"%2%\"}")
             % serverName % state->remote.toStringWithPort());
 
-          output << statesbase << "queries"     << label << " " << state->queries.load()     << "\n";
-          output << statesbase << "drops"       << label << " " << state->reuseds.load()     << "\n";
-          output << statesbase << "latency"     << label << " " << state->latencyUsec/1000.0 << "\n";
-          output << statesbase << "senderrors"  << label << " " << state->sendErrors.load()  << "\n";
-          output << statesbase << "outstanding" << label << " " << state->outstanding.load() << "\n";
-          output << statesbase << "order"       << label << " " << state->order              << "\n";
-          output << statesbase << "weight"      << label << " " << state->weight             << "\n";
+          output << statesbase << "queries"                << label << " " << state->queries.load()             << "\n";
+          output << statesbase << "drops"                  << label << " " << state->reuseds.load()             << "\n";
+          output << statesbase << "latency"                << label << " " << state->latencyUsec/1000.0         << "\n";
+          output << statesbase << "senderrors"             << label << " " << state->sendErrors.load()          << "\n";
+          output << statesbase << "outstanding"            << label << " " << state->outstanding.load()         << "\n";
+          output << statesbase << "order"                  << label << " " << state->order                      << "\n";
+          output << statesbase << "weight"                 << label << " " << state->weight                     << "\n";
+          output << statesbase << "tcpdiedsendingquery"    << label << " " << state->tcpDiedSendingQuery        << "\n";
+          output << statesbase << "tcpdiedreadingresponse" << label << " " << state->tcpDiedReadingResponse     << "\n";
+          output << statesbase << "tcpgaveup"              << label << " " << state->tcpGaveUp                  << "\n";
+          output << statesbase << "tcpreadtimeouts"        << label << " " << state->tcpReadTimeouts            << "\n";
+          output << statesbase << "tcpwritetimeouts"       << label << " " << state->tcpWriteTimeouts           << "\n";
+          output << statesbase << "tcpcurrentconnections"  << label << " " << state->tcpCurrentConnections      << "\n";
+          output << statesbase << "tcpavgqueriesperconn"   << label << " " << state->tcpAvgQueriesPerConnection << "\n";
+          output << statesbase << "tcpavgconnduration"     << label << " " << state->tcpAvgConnectionDuration   << "\n";
         }
 
         for (const auto& front : g_frontends) {
@@ -562,6 +586,14 @@ static void connectionThread(int sock, ComboAddress remote)
           {"latency", (double)(a->latencyUsec/1000.0)},
           {"queries", (double)a->queries},
           {"sendErrors", (double)a->sendErrors},
+          {"tcpDiedSendingQuery", (double)a->tcpDiedSendingQuery},
+          {"tcpDiedReadingResponse", (double)a->tcpDiedReadingResponse},
+          {"tcpGaveUp", (double)a->tcpGaveUp},
+          {"tcpReadTimeouts", (double)a->tcpReadTimeouts},
+          {"tcpWriteTimeouts", (double)a->tcpWriteTimeouts},
+          {"tcpCurrentConnections", (double)a->tcpCurrentConnections},
+          {"tcpAvgQueriesPerConnection", (double)a->tcpAvgQueriesPerConnection},
+          {"tcpAvgConnectionDuration", (double)a->tcpAvgConnectionDuration},
           {"dropRate", (double)a->dropRate}
         };
 
@@ -583,7 +615,16 @@ static void connectionThread(int sock, ComboAddress remote)
           { "address", front->local.toStringWithPort() },
           { "udp", front->udpFD >= 0 },
           { "tcp", front->tcpFD >= 0 },
-          { "queries", (double) front->queries.load() }
+          { "type", front->getType() },
+          { "queries", (double) front->queries.load() },
+          { "tcpDiedReadingQuery", (double) front->tcpDiedReadingQuery.load() },
+          { "tcpDiedSendingResponse", (double) front->tcpDiedSendingResponse.load() },
+          { "tcpGaveUp", (double) front->tcpGaveUp.load() },
+          { "tcpClientTimeouts", (double) front->tcpClientTimeouts },
+          { "tcpDownstreamTimeouts", (double) front->tcpDownstreamTimeouts },
+          { "tcpCurrentConnections", (double) front->tcpCurrentConnections },
+          { "tcpAvgQueriesPerConnection", (double) front->tcpAvgQueriesPerConnection },
+          { "tcpAvgConnectionDuration", (double) front->tcpAvgConnectionDuration },
         };
         frontends.push_back(frontend);
       }
index f2dfa39b218c0f3700345d9c20a9ece48c06d811..461c5f389649257d2ca3a08b69294b1b3bfcfd58 100644 (file)
@@ -94,9 +94,7 @@ string g_outputBuffer;
 
 vector<std::tuple<ComboAddress, bool, bool, int, string, std::set<int>>> g_locals;
 std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
-#ifdef HAVE_DNSCRYPT
 std::vector<std::tuple<ComboAddress,std::shared_ptr<DNSCryptContext>,bool, int, string, std::set<int> >> g_dnsCryptLocals;
-#endif
 #ifdef HAVE_EBPF
 shared_ptr<BPFFilter> g_defaultBPFFilter;
 std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;
@@ -156,7 +154,7 @@ try
     hadEDNS = getEDNSUDPPayloadSizeAndZ(packet, *len, &payloadSize, &z);
   }
 
-  *len=(uint16_t) (sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
+  *len=static_cast<uint16_t>(sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
   struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
   dh->ancount = dh->arcount = dh->nscount = 0;
 
@@ -191,7 +189,7 @@ struct DelayedPacket
   }
 };
 
-DelayPipe<DelayedPacket> * g_delay = 0;
+DelayPipe<DelayedPacket>* g_delay = nullptr;
 
 void doLatencyStats(double udiff)
 {
@@ -214,14 +212,11 @@ void doLatencyStats(double udiff)
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed)
 {
-  uint16_t rqtype, rqclass;
-  DNSName rqname;
-  const struct dnsheader* dh = (struct dnsheader*) response;
-
   if (responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response);
   if (dh->qdcount == 0) {
     if ((dh->rcode != RCode::NoError && dh->rcode != RCode::NXDomain) || g_allowEmptyResponse) {
       return true;
@@ -232,12 +227,15 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
     }
   }
 
+  uint16_t rqtype, rqclass;
+  DNSName rqname;
   try {
     rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
   }
-  catch(std::exception& e) {
-    if(responseLen > (ssize_t)sizeof(dnsheader))
+  catch(const std::exception& e) {
+    if(responseLen > 0 && static_cast<size_t>(responseLen) > sizeof(dnsheader)) {
       infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
+    }
     ++g_stats.nonCompliantResponses;
     return false;
   }
@@ -249,7 +247,7 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co
   return true;
 }
 
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
+static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
 {
   static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
   static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
@@ -263,21 +261,20 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
 {
   restoreFlags(dq.dh, origFlags);
 
   return addEDNSToQueryTurnedResponse(dq);
 }
 
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
+static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
 {
-  struct dnsheader* dh = (struct dnsheader*) *response;
-
   if (*responseLen < sizeof(dnsheader)) {
     return false;
   }
 
+  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(*response);
   restoreFlags(dh, origFlags);
 
   if (*responseLen == sizeof(dnsheader)) {
@@ -368,7 +365,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
 }
 
 #ifdef HAVE_DNSCRYPT
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
+static bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
 {
   if (dnsCryptQuery) {
     uint16_t encryptedResponseLen = 0;
@@ -390,9 +387,82 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize,
   }
   return true;
 }
-#endif
+#endif /* HAVE_DNSCRYPT */
+
+static bool applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr)
+{
+  DNSResponseAction::Action action=DNSResponseAction::Action::None;
+  std::string ruleresult;
+  for(const auto& lr : *localRespRulactions) {
+    if(lr.d_rule->matches(&dr)) {
+      lr.d_rule->d_matches++;
+      action=(*lr.d_action)(&dr, &ruleresult);
+      switch(action) {
+      case DNSResponseAction::Action::Allow:
+        return true;
+        break;
+      case DNSResponseAction::Action::Drop:
+        return false;
+        break;
+      case DNSResponseAction::Action::HeaderModify:
+        return true;
+        break;
+      case DNSResponseAction::Action::ServFail:
+        dr.dh->rcode = RCode::ServFail;
+        return true;
+        break;
+        /* non-terminal actions follow */
+      case DNSResponseAction::Action::Delay:
+        dr.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        break;
+      case DNSResponseAction::Action::None:
+        break;
+      }
+    }
+  }
+
+  return true;
+}
 
-static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted)
+{
+  if (!applyRulesToResponse(localRespRulactions, dr)) {
+    return false;
+  }
+
+  bool zeroScope = false;
+  if (!fixUpResponse(response, responseLen, responseSize, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, rewrittenResponse, addRoom, dr.useZeroScope ? &zeroScope : nullptr)) {
+    return false;
+  }
+
+  if (dr.packetCache && !dr.skipCache) {
+    if (!dr.useZeroScope) {
+      /* if the query was not suitable for zero-scope, for
+         example because it had an existing ECS entry so the hash is
+         not really 'no ECS', so just insert it for the existing subnet
+         since:
+         - we don't have the correct hash for a non-ECS query
+         - inserting with hash computed before the ECS replacement but with
+         the subnet extracted _after_ the replacement would not work.
+      */
+      zeroScope = false;
+    }
+    // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
+    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, *response, *responseLen, dr.tcp, dr.dh->rcode, dr.tempFailureTTL);
+  }
+
+#ifdef HAVE_DNSCRYPT
+  if (!muted) {
+    if (!encryptResponse(*response, responseLen, *responseSize, dr.tcp, dr.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
+    }
+  }
+#endif /* HAVE_DNSCRYPT */
+
+  return true;
+}
+
+static bool sendUDPResponse(int origFD, const char* response, const uint16_t responseLen, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   if(delayMsec && g_delay) {
     DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
@@ -401,7 +471,7 @@ static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, in
   else {
     ssize_t res;
     if(origDest.sin4.sin_family == 0) {
-      res = sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen());
+      res = sendto(origFD, response, responseLen, 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
     }
     else {
       res = sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
@@ -416,7 +486,7 @@ static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, in
 }
 
 
-static int pickBackendSocketForSending(DownstreamState* state)
+static int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
 {
   return state->sockets[state->socketsOffset++ % state->sockets.size()];
 }
@@ -441,15 +511,11 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 try {
   setThreadName("dnsdist/respond");
   auto localRespRulactions = g_resprulactions.getLocal();
-#ifdef HAVE_DNSCRYPT
   char packet[4096 + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE];
+  static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
   /* when the answer is encrypted in place, we need to get a copy
      of the original header before encryption to fill the ring buffer */
-  dnsheader dhCopy;
-#else
-  char packet[4096];
-#endif
-  static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
+  dnsheader cleartextDH;
   vector<uint8_t> rewrittenResponse;
 
   uint16_t queryId = 0;
@@ -465,10 +531,10 @@ try {
         char * response = packet;
         size_t responseSize = sizeof(packet);
 
-        if (got < (ssize_t) sizeof(dnsheader))
+        if (got < 0 || static_cast<size_t>(got) < sizeof(dnsheader))
           continue;
 
-        uint16_t responseLen = (uint16_t) got;
+        uint16_t responseLen = static_cast<uint16_t>(got);
         queryId = dh->id;
 
         if(queryId >= dss->idStates.size())
@@ -508,80 +574,49 @@ try {
         dh->id = ids->origID;
 
         uint16_t addRoom = 0;
-        DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = ids->uniqueId;
-#endif
-        dr.qTag = ids->qTag;
-
-        if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
-          continue;
-        }
-
-#ifdef HAVE_DNSCRYPT
-        if (ids->dnsCryptQuery) {
+        DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false);
+        if (dr.dnsCryptQuery) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
         }
-#endif
-        bool zeroScope = false;
-        if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, ids->ecsAdded, rewrittenResponse, addRoom, ids->useZeroScope ? &zeroScope : nullptr)) {
-          continue;
-        }
 
-        if (ids->packetCache && !ids->skipCache) {
-          if (!ids->useZeroScope) {
-            /* if the query was not suitable for zero-scope, for
-               example because it had an existing ECS entry so the hash is
-               not really 'no ECS', so just insert it for the existing subnet
-               since:
-               - we don't have the correct hash for a non-ECS query
-               - inserting with hash computed before the ECS replacement but with
-               the subnet extracted _after_ the replacement would not work.
-            */
-            zeroScope = false;
-          }
-          // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-          ids->packetCache->insert(zeroScope ? ids->cacheKeyNoECS : ids->cacheKey, zeroScope ? boost::none : ids->subnet, ids->origFlags, ids->dnssecOK, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
+        memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
+        if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, ids->cs && ids->cs->muted)) {
+          continue;
         }
 
         if (ids->cs && !ids->cs->muted) {
-#ifdef HAVE_DNSCRYPT
-          if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery, &dh, &dhCopy)) {
-            continue;
-          }
-#endif
-
           ComboAddress empty;
           empty.sin4.sin_family = 0;
           /* if ids->destHarvested is false, origDest holds the listening address.
              We don't want to use that as a source since it could be 0.0.0.0 for example. */
-          sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
+          sendUDPResponse(origFD, response, responseLen, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
         }
 
         ++g_stats.responses;
 
         double udiff = ids->sentTime.udiff();
-        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
+        vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), dr.remote->toStringWithPort(), udiff);
 
         struct timespec ts;
         gettime(&ts);
-        g_rings.insertResponse(ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, dss->remote);
+        g_rings.insertResponse(ts, *dr.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(got), cleartextDH, dss->remote);
 
-        if(dh->rcode == RCode::ServFail) {
+        switch (dh->rcode) {
+        case RCode::NXDomain:
+          ++g_stats.frontendNXDomain;
+          break;
+        case RCode::ServFail:
           ++g_stats.servfailResponses;
+          ++g_stats.frontendServFail;
+          break;
+        case RCode::NoError:
+          ++g_stats.frontendNoError;
+          break;
         }
         dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
 
         doLatencyStats(udiff);
 
-        /* if the FD is not -1, the state has been actively reused and we should
-           not alter anything */
-        if (ids->origFD == -1) {
-#ifdef HAVE_DNSCRYPT
-          ids->dnsCryptQuery = nullptr;
-#endif
-        }
-
         rewrittenResponse.clear();
       }
     }
@@ -962,9 +997,9 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
   }
 }
 
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
+static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, const struct timespec& now)
 {
-  g_rings.insertQuery(now,*dq.remote,*dq.qname,dq.qtype,dq.len,*dq.dh);
+  g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
 
   if(g_qcount.enabled) {
     string qname = (*dq.qname).toString(".");
@@ -1146,7 +1181,7 @@ bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int*
         break;
         /* non-terminal actions follow */
       case DNSAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+        dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
         break;
       case DNSAction::Action::None:
         /* fall-through */
@@ -1163,42 +1198,7 @@ bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int*
   return true;
 }
 
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec)
-{
-  DNSResponseAction::Action action=DNSResponseAction::Action::None;
-  std::string ruleresult;
-  for(const auto& lr : *localRespRulactions) {
-    if(lr.d_rule->matches(&dr)) {
-      lr.d_rule->d_matches++;
-      action=(*lr.d_action)(&dr, &ruleresult);
-      switch(action) {
-      case DNSResponseAction::Action::Allow:
-        return true;
-        break;
-      case DNSResponseAction::Action::Drop:
-        return false;
-        break;
-      case DNSResponseAction::Action::HeaderModify:
-        return true;
-        break;
-      case DNSResponseAction::Action::ServFail:
-        dr.dh->rcode = RCode::ServFail;
-        return true;
-        break;
-        /* non-terminal actions follow */
-      case DNSResponseAction::Action::Delay:
-        *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
-        break;
-      case DNSResponseAction::Action::None:
-        break;
-      }
-    }
-  }
-
-  return true;
-}
-
-static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
+static ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
 {
   ssize_t result;
 
@@ -1261,29 +1261,29 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
   return true;
 }
 
-#ifdef HAVE_DNSCRYPT
-static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now)
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
 {
   if (cs.dnscryptCtx) {
+#ifdef HAVE_DNSCRYPT
     vector<uint8_t> response;
     uint16_t decryptedQueryLen = 0;
 
     dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
 
-    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response);
+    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response);
 
     if (!decrypted) {
       if (response.size() > 0) {
-        sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(response.data()), static_cast<uint16_t>(response.size()), 0, dest, remote);
+        return response;
       }
-      return false;
+      throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
     }
 
     len = decryptedQueryLen;
+#endif /* HAVE_DNSCRYPT */
   }
-  return true;
+  return boost::none;
 }
-#endif /* HAVE_DNSCRYPT */
 
 bool checkQueryHeaders(const struct dnsheader* dh)
 {
@@ -1319,109 +1319,82 @@ static void queueResponse(const ClientState& cs, const char* response, uint16_t
 }
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
 
-static int sendAndEncryptUDPResponse(LocalHolders& holders, ClientState& cs, const DNSQuestion& dq, char* response, uint16_t responseLen, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, int delayMsec, const ComboAddress& dest, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf, bool cacheHit)
+/* self-generated responses or cache hits */
+static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
 {
-  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, dq.queryTime);
+  DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(dq.dh), dq.size, dq.len, dq.tcp, dq.queryTime);
+
 #ifdef HAVE_PROTOBUF
   dr.uniqueId = dq.uniqueId;
 #endif
   dr.qTag = dq.qTag;
+  dr.delayMsec = dq.delayMsec;
 
-  if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
-    return -1;
+  if (!applyRulesToResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr)) {
+    return false;
   }
 
-  if (!cs.muted) {
+  /* in case a rule changed it */
+  dq.delayMsec = dr.delayMsec;
+
 #ifdef HAVE_DNSCRYPT
-    if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
-      return -1;
-    }
-#endif
-#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
-    if (delayMsec == 0 && responsesVect != nullptr) {
-      queueResponse(cs, response, responseLen, dest, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
-      (*queuedResponses)++;
+  if (!cs.muted) {
+    if (!encryptResponse(reinterpret_cast<char*>(dq.dh), &dq.len, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
+      return false;
     }
-    else
-#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-      {
-        sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, *dq.remote);
-      }
   }
+#endif /* HAVE_DNSCRYPT */
 
   if (cacheHit) {
     ++g_stats.cacheHits;
   }
+
+  switch (dr.dh->rcode) {
+  case RCode::NXDomain:
+    ++g_stats.frontendNXDomain;
+    break;
+  case RCode::ServFail:
+    ++g_stats.frontendServFail;
+    break;
+  case RCode::NoError:
+    ++g_stats.frontendNoError;
+    break;
+  }
+
   doLatencyStats(0);  // we're not going to measure this
-  return 0;
+  return true;
 }
 
-static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
-  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
-  uint16_t queryId = 0;
+  const uint16_t queryId = ntohs(dq.dh->id);
 
   try {
-    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
-      return;
-    }
-
     /* we need an accurate ("real") value for the response and
        to store into the IDS, but not for insertion into the
        rings for example */
-    struct timespec queryRealTime;
     struct timespec now;
     gettime(&now);
-    gettime(&queryRealTime, true);
-
-    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-#ifdef HAVE_DNSCRYPT
-    if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) {
-      return;
-    }
-#endif
-
-    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
-    queryId = ntohs(dh->id);
-
-    if (!checkQueryHeaders(dh)) {
-      return;
-    }
 
     string poolname;
-    int delayMsec = 0;
-    const uint16_t * flags = getFlagsFromDNSHeader(dh);
-    const uint16_t origFlags = *flags;
-    uint16_t qtype, qclass;
-    unsigned int consumed = 0;
-    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
-    bool dnssecOK = false;
 
-    if (!processQuery(holders, dq, poolname, &delayMsec, now))
-    {
-      return;
+    if (!applyRulesToQuery(holders, dq, poolname, now)) {
+      return ProcessQueryResult::Drop;
     }
 
     if(dq.dh->qr) { // something turned it into a response
-      fixUpQueryTurnedResponse(dq, origFlags);
-
-      if (!cs.muted) {
-        char* response = query;
-        uint16_t responseLen = dq.len;
+      fixUpQueryTurnedResponse(dq, dq.origFlags);
 
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
-
-        ++g_stats.selfAnswered;
+      if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+        return ProcessQueryResult::Drop;
       }
 
-      return;
+      ++g_stats.selfAnswered;
+      return ProcessQueryResult::SendAnswer;
     }
 
-    DownstreamState* ss = nullptr;
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
-    std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+    dq.packetCache = serverPool->packetCache;
     auto policy = *(holders.policy);
     if (serverPool->policy != nullptr) {
       policy = *(serverPool->policy);
@@ -1429,77 +1402,148 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     auto servers = serverPool->getServers();
     if (policy.isLua) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      ss = policy.policy(servers, &dq).get();
+      selectedBackend = policy.policy(servers, &dq);
     }
     else {
-      ss = policy.policy(servers, &dq).get();
+      selectedBackend = policy.policy(servers, &dq);
     }
 
-    bool ednsAdded = false;
-    bool ecsAdded = false;
-    uint32_t cacheKeyNoECS = 0;
-    uint32_t cacheKey = 0;
-    boost::optional<Netmask> subnet;
     uint16_t cachedResponseSize = dq.size;
-    uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
-    bool useZeroScope = false;
+    uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
 
-    if (packetCache && !dq.skipCache) {
-      dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
+    if (dq.packetCache && !dq.skipCache) {
+      dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
     }
 
-    if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
+    if (dq.useECS && ((selectedBackend && selectedBackend->useECS) || (!selectedBackend && serverPool->getECS()))) {
       // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
-      if (packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && packetCache->isECSParsingEnabled()) {
-        if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
-          sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-          return;
+      if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
+        if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
+          dq.len = cachedResponseSize;
+
+          if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+            return ProcessQueryResult::Drop;
+          }
+
+          return ProcessQueryResult::SendAnswer;
         }
 
-        if (!subnet) {
+        if (!dq.subnet) {
           /* there was no existing ECS on the query, enable the zero-scope feature */
-          useZeroScope = true;
+          dq.useZeroScope = true;
         }
       }
 
-      if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
-        vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
-        return;
+      if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
+        vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
+        return ProcessQueryResult::Drop;
       }
     }
 
-    if (packetCache && !dq.skipCache) {
-      if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
-        sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
-        return;
+    if (dq.packetCache && !dq.skipCache) {
+      if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
+        dq.len = cachedResponseSize;
+
+        if (!prepareOutgoingResponse(holders, cs, dq, true)) {
+          return ProcessQueryResult::Drop;
+        }
+
+        return ProcessQueryResult::SendAnswer;
       }
       ++g_stats.cacheMisses;
     }
 
-    if(!ss) {
+    if(!selectedBackend) {
       ++g_stats.noPolicy;
 
-      if (g_servFailOnNoPolicy && !cs.muted) {
-        char* response = query;
-        uint16_t responseLen = dq.len;
-        restoreFlags(dh, origFlags);
+      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+      if (g_servFailOnNoPolicy) {
+        restoreFlags(dq.dh, dq.origFlags);
 
         dq.dh->rcode = RCode::ServFail;
         dq.dh->qr = true;
 
-        sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
-
+        if (!prepareOutgoingResponse(holders, cs, dq, false)) {
+          return ProcessQueryResult::Drop;
+        }
         // no response-only statistics counter to update.
+        return ProcessQueryResult::SendAnswer;
       }
-      vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
+
+      return ProcessQueryResult::Drop;
+    }
+
+    if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
+      addXPF(dq, selectedBackend->xpfRRCode, g_preserveTrailingData);
+    }
+
+    selectedBackend->queries++;
+    return ProcessQueryResult::PassToBackend;
+  }
+  catch(const std::exception& e){
+    vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
+  }
+  return ProcessQueryResult::Drop;
+}
+
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+{
+  assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
+  uint16_t queryId = 0;
+
+  try {
+    if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
       return;
     }
 
-    if (dq.addXPF && ss->xpfRRCode != 0) {
-      addXPF(dq, ss->xpfRRCode, g_preserveTrailingData);
+    /* we need an accurate ("real") value for the response and
+       to store into the IDS, but not for insertion into the
+       rings for example */
+    struct timespec queryRealTime;
+    gettime(&queryRealTime, true);
+
+    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
+    auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
+    if (dnsCryptResponse) {
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), 0, dest, remote);
+      return;
     }
 
-    ss->queries++;
+    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+    queryId = ntohs(dh->id);
+
+    if (!checkQueryHeaders(dh)) {
+      return;
+    }
+
+    uint16_t qtype, qclass;
+    unsigned int consumed = 0;
+    DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+    DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
+    dq.dnsCryptQuery = std::move(dnsCryptQuery);
+    std::shared_ptr<DownstreamState> ss{nullptr};
+    auto result = processQuery(dq, cs, holders, ss);
+
+    if (result == ProcessQueryResult::Drop) {
+      return;
+    }
+
+    if (result == ProcessQueryResult::SendAnswer) {
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+      if (dq.delayMsec == 0 && responsesVect != nullptr) {
+        queueResponse(cs, reinterpret_cast<char*>(dq.dh), dq.len, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+        (*queuedResponses)++;
+        return;
+      }
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+      /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
+      sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dq.dh), dq.len, dq.delayMsec, dest, *dq.remote);
+      return;
+    }
+
+    if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
+      return;
+    }
 
     unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
     IDState* ids = &ss->idStates[idOffset];
@@ -1517,24 +1561,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     ids->cs = &cs;
     ids->origID = dh->id;
-    ids->origRemote = remote;
-    ids->sentTime.set(queryRealTime);
-    ids->qname = qname;
-    ids->qtype = dq.qtype;
-    ids->qclass = dq.qclass;
-    ids->delayMsec = delayMsec;
-    ids->tempFailureTTL = dq.tempFailureTTL;
-    ids->origFlags = origFlags;
-    ids->cacheKey = cacheKey;
-    ids->cacheKeyNoECS = cacheKeyNoECS;
-    ids->subnet = subnet;
-    ids->skipCache = dq.skipCache;
-    ids->packetCache = packetCache;
-    ids->ednsAdded = ednsAdded;
-    ids->ecsAdded = ecsAdded;
-    ids->useZeroScope = useZeroScope;
-    ids->qTag = dq.qTag;
-    ids->dnssecOK = dnssecOK;
+    setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
     /* If we couldn't harvest the real dest addr, still
        write down the listening addr since it will be useful
@@ -1550,12 +1577,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids->origDest = cs.local;
       ids->destHarvested = false;
     }
-#ifdef HAVE_DNSCRYPT
-    ids->dnsCryptQuery = dnsCryptQuery;
-#endif
-#ifdef HAVE_PROTOBUF
-    ids->uniqueId = dq.uniqueId;
-#endif
 
     dh->id = idOffset;
 
@@ -1631,8 +1652,8 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
       unsigned int got = msgVec[msgIdx].msg_len;
       const ComboAddress& remote = recvData[msgIdx].remote;
 
-      if (got < sizeof(struct dnsheader)) {
-        g_stats.nonCompliantQueries++;
+      if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
+        ++g_stats.nonCompliantQueries;
         continue;
       }
 
@@ -1721,12 +1742,12 @@ uint16_t getRandomDNSID()
 #endif
 }
 
-static bool upCheck(DownstreamState& ds)
+static bool upCheck(const shared_ptr<DownstreamState>& ds)
 try
 {
-  DNSName checkName = ds.checkName;
-  uint16_t checkType = ds.checkType.getCode();
-  uint16_t checkClass = ds.checkClass;
+  DNSName checkName = ds->checkName;
+  uint16_t checkType = ds->checkType.getCode();
+  uint16_t checkClass = ds->checkClass;
   dnsheader checkHeader;
   memset(&checkHeader, 0, sizeof(checkHeader));
 
@@ -1734,13 +1755,13 @@ try
   checkHeader.id = getRandomDNSID();
 
   checkHeader.rd = true;
-  if (ds.setCD) {
+  if (ds->setCD) {
     checkHeader.cd = true;
   }
 
-  if (ds.checkFunction) {
+  if (ds->checkFunction) {
     std::lock_guard<std::mutex> lock(g_luamutex);
-    auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader);
+    auto ret = ds->checkFunction(checkName, checkType, checkClass, &checkHeader);
     checkName = std::get<0>(ret);
     checkType = std::get<1>(ret);
     checkClass = std::get<2>(ret);
@@ -1751,31 +1772,31 @@ try
   dnsheader * requestHeader = dpw.getHeader();
   *requestHeader = checkHeader;
 
-  Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM);
+  Socket sock(ds->remote.sin4.sin_family, SOCK_DGRAM);
   sock.setNonBlocking();
-  if (!IsAnyAddress(ds.sourceAddr)) {
+  if (!IsAnyAddress(ds->sourceAddr)) {
     sock.setReuseAddr();
-    sock.bind(ds.sourceAddr);
+    sock.bind(ds->sourceAddr);
   }
-  sock.connect(ds.remote);
-  ssize_t sent = udpClientSendRequestToBackend(&ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
+  sock.connect(ds->remote);
+  ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), reinterpret_cast<char*>(&packet[0]), packet.size(), true);
   if (sent < 0) {
     int ret = errno;
     if (g_verboseHealthChecks)
-      infolog("Error while sending a health check query to backend %s: %d", ds.getNameWithAddr(), ret);
+      infolog("Error while sending a health check query to backend %s: %d", ds->getNameWithAddr(), ret);
     return false;
   }
 
-  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds.checkTimeout / 1000, /* remaining ms to us */ (ds.checkTimeout % 1000) * 1000);
+  int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds->checkTimeout / 1000, /* remaining ms to us */ (ds->checkTimeout % 1000) * 1000);
   if(ret < 0 || !ret) { // error, timeout, both are down!
     if (ret < 0) {
       ret = errno;
       if (g_verboseHealthChecks)
-        infolog("Error while waiting for the health check response from backend %s: %d", ds.getNameWithAddr(), ret);
+        infolog("Error while waiting for the health check response from backend %s: %d", ds->getNameWithAddr(), ret);
     }
     else {
       if (g_verboseHealthChecks)
-        infolog("Timeout while waiting for the health check response from backend %s", ds.getNameWithAddr());
+        infolog("Timeout while waiting for the health check response from backend %s", ds->getNameWithAddr());
     }
     return false;
   }
@@ -1785,9 +1806,9 @@ try
   sock.recvFrom(reply, from);
 
   /* we are using a connected socket but hey.. */
-  if (from != ds.remote) {
+  if (from != ds->remote) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds.remote.toStringWithPort());
+      infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds->remote.toStringWithPort());
     return false;
   }
 
@@ -1795,31 +1816,31 @@ try
 
   if (reply.size() < sizeof(*responseHeader)) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds.getNameWithAddr(), sizeof(*responseHeader));
+      infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds->getNameWithAddr(), sizeof(*responseHeader));
     return false;
   }
 
   if (responseHeader->id != requestHeader->id) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds.getNameWithAddr(), requestHeader->id);
+      infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds->getNameWithAddr(), requestHeader->id);
     return false;
   }
 
   if (!responseHeader->qr) {
     if (g_verboseHealthChecks)
-      infolog("Invalid health check response from backend %s, expecting QR to be set", ds.getNameWithAddr());
+      infolog("Invalid health check response from backend %s, expecting QR to be set", ds->getNameWithAddr());
     return false;
   }
 
   if (responseHeader->rcode == RCode::ServFail) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with ServFail", ds.getNameWithAddr());
+      infolog("Backend %s responded to health check with ServFail", ds->getNameWithAddr());
     return false;
   }
 
-  if (ds.mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
+  if (ds->mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with %s while mustResolve is set", ds.getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
+      infolog("Backend %s responded to health check with %s while mustResolve is set", ds->getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
     return false;
   }
 
@@ -1829,7 +1850,7 @@ try
 
   if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
     if (g_verboseHealthChecks)
-      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
+      infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds->getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
     return false;
   }
 
@@ -1838,13 +1859,13 @@ try
 catch(const std::exception& e)
 {
   if (g_verboseHealthChecks)
-    infolog("Error checking the health of backend %s: %s", ds.getNameWithAddr(), e.what());
+    infolog("Error checking the health of backend %s: %s", ds->getNameWithAddr(), e.what());
   return false;
 }
 catch(...)
 {
   if (g_verboseHealthChecks)
-    infolog("Unknown exception while checking the health of backend %s", ds.getNameWithAddr());
+    infolog("Unknown exception while checking the health of backend %s", ds->getNameWithAddr());
   return false;
 }
 
@@ -1958,15 +1979,15 @@ static void healthChecksThread()
         continue;
       dss->lastCheck = 0;
       if(dss->availability==DownstreamState::Availability::Auto) {
-        bool newState=upCheck(*dss);
+        bool newState=upCheck(dss);
         if (newState) {
           /* check succeeded */
           dss->currentCheckFailures = 0;
 
           if (!dss->upStatus) {
             /* we were marked as down */
-            dss->consecutiveSuccesfulChecks++;
-            if (dss->consecutiveSuccesfulChecks < dss->minRiseSuccesses) {
+            dss->consecutiveSuccessfulChecks++;
+            if (dss->consecutiveSuccessfulChecks < dss->minRiseSuccesses) {
               /* if we need more than one successful check to rise
                  and we didn't reach the threshold yet,
                  let's stay down */
@@ -1976,7 +1997,7 @@ static void healthChecksThread()
         }
         else {
           /* check failed */
-          dss->consecutiveSuccesfulChecks = 0;
+          dss->consecutiveSuccessfulChecks = 0;
 
           if (dss->upStatus) {
             /* we are currently up */
@@ -2002,7 +2023,7 @@ static void healthChecksThread()
 
           dss->upStatus = newState;
           dss->currentCheckFailures = 0;
-          dss->consecutiveSuccesfulChecks = 0;
+          dss->consecutiveSuccessfulChecks = 0;
           if (g_snmpAgent && g_snmpTrapsEnabled) {
             g_snmpAgent->sendBackendStatusChangeTrap(dss);
           }
@@ -2206,7 +2227,7 @@ try
 
   signal(SIGPIPE, SIG_IGN);
   signal(SIGCHLD, SIG_IGN);
-  openlog("dnsdist", LOG_PID, LOG_DAEMON);
+  openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON);
 
 #ifdef HAVE_LIBSODIUM
   if (sodium_init() == -1) {
@@ -2330,6 +2351,9 @@ try
 #ifdef HAVE_FSTRM
       cout<<"fstrm ";
 #endif
+#ifdef HAVE_LIBCRYPTO
+      cout<<"ipcipher ";
+#endif
 #ifdef HAVE_LIBSODIUM
       cout<<"libsodium ";
 #endif
@@ -2555,7 +2579,6 @@ try
     tcpBindsCount++;
   }
 
-#ifdef HAVE_DNSCRYPT
   for(auto& dcLocal : g_dnsCryptLocals) {
     ClientState* cs = new ClientState;
     cs->local = std::get<0>(dcLocal);
@@ -2657,7 +2680,6 @@ try
     g_frontends.push_back(cs);
     tcpBindsCount++;
   }
-#endif
 
   for(auto& frontend : g_tlslocals) {
     ClientState* cs = new ClientState;
@@ -2774,7 +2796,7 @@ try
     g_snmpAgent->run();
   }
 
-  g_tcpclientthreads = std::make_shared<TCPClientCollection>(g_maxTCPClientThreads, g_useTCPSinglePipe);
+  g_tcpclientthreads = std::unique_ptr<TCPClientCollection>(new TCPClientCollection(g_maxTCPClientThreads, g_useTCPSinglePipe));
 
   for(auto& t : todo)
     t();
@@ -2803,7 +2825,7 @@ try
 
   for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
     if(dss->availability==DownstreamState::Availability::Auto) {
-      bool newState=upCheck(*dss);
+      bool newState=upCheck(dss);
       warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
       dss->upStatus = newState;
     }
index 925d1d0d1d639218a1c44339b06d6abbf72e13a4..663e18061bdb115b3deaef847b91d12ef1fa5d5e 100644 (file)
@@ -61,33 +61,47 @@ typedef std::unordered_map<string, string> QTag;
 struct DNSQuestion
 {
   DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_):
-    qname(name), qtype(type), qclass(class_), local(lc), remote(rem), dh(header), size(bufferSize), consumed(consumed_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tempFailureTTL(boost::none), tcp(isTcp), queryTime(queryTime_), ecsOverride(g_ECSOverride) { }
+    qname(name), local(lc), remote(rem), dh(header), queryTime(queryTime_), size(bufferSize), consumed(consumed_), tempFailureTTL(boost::none), qtype(type), qclass(class_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) {
+    const uint16_t* flags = getFlagsFromDNSHeader(dh);
+    origFlags = *flags;
+  }
 
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
   Netmask ecs;
-  const DNSName* qname;
-  const uint16_t qtype;
-  const uint16_t qclass;
-  const ComboAddress* local;
-  const ComboAddress* remote;
+  boost::optional<Netmask> subnet;
+  const DNSName* qname{nullptr};
+  const ComboAddress* local{nullptr};
+  const ComboAddress* remote{nullptr};
   std::shared_ptr<QTag> qTag{nullptr};
   std::shared_ptr<std::map<uint16_t, EDNSOptionView> > ednsOptions;
-  struct dnsheader* dh;
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
+  struct dnsheader* dh{nullptr};
+  const struct timespec* queryTime{nullptr};
   size_t size;
   unsigned int consumed{0};
+  int delayMsec{0};
+  boost::optional<uint32_t> tempFailureTTL;
+  uint32_t cacheKeyNoECS;
+  uint32_t cacheKey;
+  const uint16_t qtype;
+  const uint16_t qclass;
   uint16_t len;
   uint16_t ecsPrefixLength;
+  uint16_t origFlags;
   uint8_t ednsRCode{0};
-  boost::optional<uint32_t> tempFailureTTL;
   const bool tcp;
-  const struct timespec* queryTime;
   bool skipCache{false};
   bool ecsOverride;
   bool useECS{true};
   bool addXPF{true};
   bool ecsSet{false};
+  bool ecsAdded{false};
+  bool ednsAdded{false};
+  bool useZeroScope{false};
+  bool dnssecOK{false};
 };
 
 struct DNSResponse : DNSQuestion
@@ -208,6 +222,9 @@ struct DNSDistStats
   stat_t responses{0};
   stat_t servfailResponses{0};
   stat_t queries{0};
+  stat_t frontendNXDomain{0};
+  stat_t frontendServFail{0};
+  stat_t frontendNoError{0};
   stat_t nonCompliantQueries{0};
   stat_t nonCompliantResponses{0};
   stat_t rdQueries{0};
@@ -235,6 +252,9 @@ struct DNSDistStats
     {"responses", &responses},
     {"servfail-responses", &servfailResponses},
     {"queries", &queries},
+    {"frontend-nxdomain", &frontendNXDomain},
+    {"frontend-servfail", &frontendServFail},
+    {"frontend-noerror", &frontendNoError},
     {"acl-drops", &aclDrops},
     {"rule-drop", &ruleDrop},
     {"rule-nxdomain", &ruleNXDomain},
@@ -324,6 +344,9 @@ struct MetricDefinitionStorage {
     { "responses",              MetricDefinition(PrometheusMetricType::counter, "Number of responses received from backends") },
     { "servfail-responses",     MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers received from backends") },
     { "queries",                MetricDefinition(PrometheusMetricType::counter, "Number of received queries")},
+    { "frontend-nxdomain",      MetricDefinition(PrometheusMetricType::counter, "Number of NXDomain answers sent to clients")},
+    { "frontend-servfail",      MetricDefinition(PrometheusMetricType::counter, "Number of SERVFAIL answers sent to clients")},
+    { "frontend-noerror",       MetricDefinition(PrometheusMetricType::counter, "Number of NoError answers sent to clients")},
     { "acl-drops",              MetricDefinition(PrometheusMetricType::counter, "Number of packets dropped because of the ACL")},
     { "rule-drop",              MetricDefinition(PrometheusMetricType::counter, "Number of queries dropped because of a rule")},
     { "rule-nxdomain",          MetricDefinition(PrometheusMetricType::counter, "Number of NXDomain answers returned because of a rule")},
@@ -516,9 +539,7 @@ struct IDState
   ComboAddress origDest;                                      // 28
   StopWatch sentTime;                                         // 16
   DNSName qname;                                              // 80
-#ifdef HAVE_DNSCRYPT
   std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
-#endif
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
 #endif
@@ -562,11 +583,18 @@ struct ClientState
 {
   std::set<int> cpus;
   ComboAddress local;
-#ifdef HAVE_DNSCRYPT
   std::shared_ptr<DNSCryptContext> dnscryptCtx{nullptr};
-#endif
   shared_ptr<TLSFrontend> tlsFrontend;
   std::atomic<uint64_t> queries{0};
+  std::atomic<uint64_t> tcpDiedReadingQuery{0};
+  std::atomic<uint64_t> tcpDiedSendingResponse{0};
+  std::atomic<uint64_t> tcpGaveUp{0};
+  std::atomic<uint64_t> tcpClientTimeouts{0};
+  std::atomic<uint64_t> tcpDownstreamTimeouts{0};
+  std::atomic<uint64_t> tcpCurrentConnections{0};
+  std::atomic<double> tcpAvgQueriesPerConnection{0.0};
+  /* in ms */
+  std::atomic<double> tcpAvgConnectionDuration{0.0};
   int udpFD{-1};
   int tcpFD{-1};
   bool muted{false};
@@ -576,6 +604,20 @@ struct ClientState
     return udpFD != -1 ? udpFD : tcpFD;
   }
 
+  std::string getType() const
+  {
+    std::string result = udpFD != -1 ? "UDP" : "TCP";
+
+    if (tlsFrontend) {
+      result += " (DNS over TLS)";
+    }
+    else if (dnscryptCtx) {
+      result += " (DNSCrypt)";
+    }
+
+    return result;
+  }
+
 #ifdef HAVE_EBPF
   shared_ptr<BPFFilter> d_filter;
 
@@ -595,6 +637,12 @@ struct ClientState
     d_filter = bpf;
   }
 #endif /* HAVE_EBPF */
+
+  void updateTCPMetrics(size_t queries, uint64_t durationMs)
+  {
+    tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (queries / 100.0);
+    tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
+  }
 };
 
 class TCPClientCollection {
@@ -617,6 +665,14 @@ public:
       if (pipe(d_singlePipe) < 0) {
         throw std::runtime_error("Error creating the TCP single communication pipe: " + string(strerror(errno)));
       }
+
+      if (!setNonBlocking(d_singlePipe[0])) {
+        int err = errno;
+        close(d_singlePipe[0]);
+        close(d_singlePipe[1]);
+        throw std::runtime_error("Error setting the TCP single communication pipe non-blocking: " + string(strerror(err)));
+      }
+
       if (!setNonBlocking(d_singlePipe[1])) {
         int err = errno;
         close(d_singlePipe[0]);
@@ -650,7 +706,7 @@ public:
   void addTCPClientThread();
 };
 
-extern std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+extern std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
 
 struct DownstreamState
 {
@@ -693,6 +749,15 @@ struct DownstreamState
     std::atomic<uint64_t> reuseds{0};
     std::atomic<uint64_t> queries{0};
   } prev;
+  std::atomic<uint64_t> tcpDiedSendingQuery{0};
+  std::atomic<uint64_t> tcpDiedReadingResponse{0};
+  std::atomic<uint64_t> tcpGaveUp{0};
+  std::atomic<uint64_t> tcpReadTimeouts{0};
+  std::atomic<uint64_t> tcpWriteTimeouts{0};
+  std::atomic<uint64_t> tcpCurrentConnections{0};
+  std::atomic<double> tcpAvgQueriesPerConnection{0.0};
+  /* in ms */
+  std::atomic<double> tcpAvgConnectionDuration{0.0};
   string name;
   size_t socketsOffset{0};
   double queryLoad{0.0};
@@ -710,7 +775,7 @@ struct DownstreamState
   uint16_t xpfRRCode{0};
   uint16_t checkTimeout{1000}; /* in milliseconds */
   uint8_t currentCheckFailures{0};
-  uint8_t consecutiveSuccesfulChecks{0};
+  uint8_t consecutiveSuccessfulChecks{0};
   uint8_t maxCheckFailures{1};
   uint8_t minRiseSuccesses{1};
   StopWatch sw;
@@ -764,6 +829,12 @@ struct DownstreamState
   void hash();
   void setId(const boost::uuids::uuid& newId);
   void setWeight(int newWeight);
+
+  void updateTCPMetrics(size_t queries, uint64_t durationMs)
+  {
+    tcpAvgQueriesPerConnection = (99.0 * tcpAvgQueriesPerConnection / 100.0) + (queries / 100.0);
+    tcpAvgConnectionDuration = (99.0 * tcpAvgConnectionDuration / 100.0) + (durationMs / 100.0);
+  }
 };
 using servers_t =vector<std::shared_ptr<DownstreamState>>;
 
@@ -959,7 +1030,7 @@ extern std::string g_apiConfigDirectory;
 extern bool g_servFailOnNoPolicy;
 extern uint32_t g_hashperturb;
 extern bool g_useTCPSinglePipe;
-extern std::atomic<uint16_t> g_downstreamTCPCleanupInterval;
+extern uint16_t g_downstreamTCPCleanupInterval;
 extern size_t g_udpVectorSize;
 extern bool g_preserveTrailingData;
 extern bool g_allowEmptyResponse;
@@ -1023,19 +1094,13 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n
 void resetLuaSideEffect(); // reset to indeterminate state
 
 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed);
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now);
-bool processResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, int* delayMsec);
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags);
-bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope);
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
+bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted);
+
 bool checkQueryHeaders(const struct dnsheader* dh);
 
-#ifdef HAVE_DNSCRYPT
 extern std::vector<std::tuple<ComboAddress, std::shared_ptr<DNSCryptContext>, bool, int, std::string, std::set<int> > > g_dnsCryptLocals;
-
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
 int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response);
-#endif
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp);
 
 bool addXPF(DNSQuestion& dq, uint16_t optionCode);
 
@@ -1049,3 +1114,9 @@ extern DNSDistSNMPAgent* g_snmpAgent;
 extern bool g_addEDNSToSelfGeneratedResponses;
 
 static const size_t s_udpIncomingBufferSize{1500};
+
+enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend };
+ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
+
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP);
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname);
index ceeb6d04048aef9f151cb51b60be921855be2a60..3c438681454a1f657f410cf7d41b5ae88045851d 100644 (file)
@@ -2,7 +2,8 @@ AM_CPPFLAGS += $(SYSTEMD_CFLAGS) $(LUA_CFLAGS) $(LIBEDIT_CFLAGS) $(LIBSODIUM_CFL
 
 ACLOCAL_AMFLAGS = -I m4
 
-SUBDIRS=ext/yahttp
+SUBDIRS=ext/ipcrypt \
+       ext/yahttp
 
 CLEANFILES = \
        dnsmessage.pb.cc \
@@ -44,6 +45,10 @@ AM_CPPFLAGS += $(GNUTLS_CFLAGS)
 endif
 endif
 
+if HAVE_LIBCRYPTO
+AM_CPPFLAGS += $(LIBCRYPTO_INCLUDES)
+endif
+
 EXTRA_DIST=COPYING \
           dnslabeltext.rl \
           dnsdistconf.lua \
@@ -98,11 +103,13 @@ dnsdist_SOURCES = \
        dnsdist-dnscrypt.cc \
        dnsdist-dynblocks.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnsdist-idstate.cc \
        dnsdist-lua.hh dnsdist-lua.cc \
        dnsdist-lua-actions.cc \
        dnsdist-lua-bindings.cc \
        dnsdist-lua-bindings-dnsquestion.cc \
        dnsdist-lua-inspection.cc \
+       dnsdist-lua-inspection-ffi.cc dnsdist-lua-inspection-ffi.hh \
        dnsdist-lua-rules.cc \
        dnsdist-lua-vars.cc \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
@@ -165,19 +172,25 @@ dnsdist_LDADD = \
        $(SANITIZER_FLAGS) \
        $(SYSTEMD_LIBS) \
        $(NET_SNMP_LIBS) \
-       $(LIBCAP_LIBS)
+       $(LIBCAP_LIBS) \
+       $(IPCRYPT_LIBS)
 
 if HAVE_RE2
 dnsdist_LDADD += $(RE2_LIBS)
 endif
 
+if HAVE_LIBCRYPTO
+dnsdist_LDADD += $(LIBCRYPTO_LIBS)
+dnsdist_SOURCES += ipcipher.cc ipcipher.hh
+endif
+
 if HAVE_DNS_OVER_TLS
 if HAVE_GNUTLS
 dnsdist_LDADD += -lgnutls
 endif
 
 if HAVE_LIBSSL
-dnsdist_LDADD += $(LIBSSL_LIBS) $(LIBCRYPTO_LIBS)
+dnsdist_LDADD += $(LIBSSL_LIBS)
 endif
 endif
 
@@ -204,20 +217,6 @@ dnsdist.$(OBJEXT): dnsmessage.pb.cc dnstap.pb.cc
 endif
 endif
 
-if HAVE_FREEBSD
-dnsdist_SOURCES += kqueuemplexer.cc
-endif
-
-if HAVE_LINUX
-dnsdist_SOURCES += epollmplexer.cc
-endif
-
-if HAVE_SOLARIS
-dnsdist_SOURCES += \
-        devpollmplexer.cc \
-        portsmplexer.cc
-endif
-
 testrunner_SOURCES = \
        base64.hh \
        dns.hh \
@@ -231,6 +230,7 @@ testrunner_SOURCES = \
        test-dnsdistrules_cc.cc \
        test-dnsparser_cc.cc \
        test-iputils_hh.cc \
+       test-mplexer.cc \
        cachecleaner.hh \
        dnsdist.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
@@ -250,14 +250,35 @@ testrunner_SOURCES = \
        misc.cc misc.hh \
        namespaces.hh \
        pdnsexception.hh \
+       pollmplexer.cc \
        qtype.cc qtype.hh \
        sholder.hh \
        sodcrypto.cc \
        sstuff.hh \
+       statnode.cc statnode.hh \
        threadname.hh threadname.cc \
        testrunner.cc \
        xpf.cc xpf.hh
 
+if HAVE_FREEBSD
+dnsdist_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
+endif
+
+if HAVE_LINUX
+dnsdist_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
+endif
+
+if HAVE_SOLARIS
+dnsdist_SOURCES += \
+        devpollmplexer.cc \
+        portsmplexer.cc
+testrunner_SOURCES += \
+        devpollmplexer.cc \
+        portsmplexer.cc
+endif
+
 testrunner_LDFLAGS = \
        $(AM_LDFLAGS) \
        $(PROGRAM_LDFLAGS) \
index e28decf08c9104febc94fd993911b94bbf436ec1..7d8014c955abcc74a910a64bf9a4a7aceb031c48 100644 (file)
@@ -17,6 +17,9 @@ AC_DEFINE([DNSDIST], [1],
 LT_PREREQ([2.2.2])
 LT_INIT([disable-static])
 
+CFLAGS="-Wall -g -O3 $CFLAGS"
+CXXFLAGS="-Wall -g -O3 $CXXFLAGS"
+
 PDNS_WITH_LIBSODIUM
 PDNS_CHECK_DNSTAP
 PDNS_CHECK_RAGEL([dnslabeltext.cc], [www.dnsdist.org])
@@ -32,8 +35,7 @@ PDNS_CHECK_SECURE_MEMSET
 
 PDNS_WITH_PROTOBUF
 
-boost_required_version=1.42
-BOOST_REQUIRE([$boost_required_version])
+BOOST_REQUIRE([1.42])
 
 PDNS_ENABLE_UNIT_TESTS
 PDNS_WITH_RE2
@@ -48,20 +50,27 @@ AM_CONDITIONAL([HAVE_SYSTEMD], [ test x"$systemd" = "xy" ])
 
 AC_SUBST([YAHTTP_CFLAGS], ['-I$(top_srcdir)/ext/yahttp'])
 AC_SUBST([YAHTTP_LIBS], ['$(top_builddir)/ext/yahttp/yahttp/libyahttp.la'])
+AC_SUBST([IPCRYPT_CFLAGS], ['-I$(top_srcdir)/ext/ipcrypt'])
+AC_SUBST([IPCRYPT_LIBS], ['$(top_builddir)/ext/ipcrypt/libipcrypt.la'])
 
 PDNS_WITH_LUA([mandatory])
+AS_IF([test "x$LUAPC" = "xluajit"], [
+  # export all symbols to be able to use the Lua FFI interface
+  AC_MSG_NOTICE([Adding -rdynamic to export all symbols for the Lua FFI interface])
+  LDFLAGS="$LDFLAGS -rdynamic"
+])
 PDNS_CHECK_LUA_HPP
 
 AM_CONDITIONAL([HAVE_GNUTLS], [false])
 AM_CONDITIONAL([HAVE_LIBSSL], [false])
+
+PDNS_CHECK_LIBCRYPTO
+
 DNSDIST_ENABLE_DNS_OVER_TLS
+
 AS_IF([test "x$enable_dns_over_tls" != "xno"], [
   DNSDIST_WITH_GNUTLS
   DNSDIST_WITH_LIBSSL
-  AS_IF([test "$HAVE_LIBSSL" = "1"], [
-    # we need libcrypto if libssl is enabled
-    PDNS_CHECK_LIBCRYPTO
-  ])
   AS_IF([test "$HAVE_GNUTLS" = "0" -a "$HAVE_LIBSSL" = "0"], [
     AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
   ])
@@ -93,11 +102,12 @@ LDFLAGS="$RELRO_LDFLAGS $LDFLAGS"
 
 CFLAGS="$PIE_CFLAGS $CFLAGS"
 CXXFLAGS="$PIE_CFLAGS $CXXFLAGS"
+
 PROGRAM_LDFLAGS="$PIE_LDFLAGS $PROGRAM_LDFLAGS"
 AC_SUBST([PROGRAM_LDFLAGS])
 
 AC_SUBST([AM_CPPFLAGS],
-  ["AS_ESCAPE([-I$(top_builddir) -I$(top_srcdir)]) -Wall -O3 -pthread $BOOST_CPPFLAGS"]
+  ["AS_ESCAPE([-I$(top_builddir) -I$(top_srcdir)]) $THREADFLAGS $BOOST_CPPFLAGS"]
 )
 
 AC_ARG_VAR(PACKAGEVERSION, [The version used in secpoll queries])
@@ -106,8 +116,9 @@ AS_IF([test "x$PACKAGEVERSION" != "x"],
 )
 
 AC_CONFIG_FILES([Makefile
-       ext/yahttp/Makefile
-       ext/yahttp/yahttp/Makefile])
+        ext/yahttp/Makefile
+        ext/yahttp/yahttp/Makefile
+        ext/ipcrypt/Makefile])
 
 AC_OUTPUT
 
@@ -142,6 +153,10 @@ AS_IF([test "x$systemd" != "xn"],
   [AC_MSG_NOTICE([systemd: yes])],
   [AC_MSG_NOTICE([systemd: no])]
 )
+AS_IF([test "x$LIBCRYPTO_LIBS" != "x"],
+  [AC_MSG_NOTICE([ipcipher: yes])],
+  [AC_MSG_NOTICE([ipcipher: no])]
+)
 AS_IF([test "x$LIBSODIUM_LIBS" != "x"],
   [AC_MSG_NOTICE([libsodium: yes])],
   [AC_MSG_NOTICE([libsodium: no])]
diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc
new file mode 100644 (file)
index 0000000..169ba64
--- /dev/null
@@ -0,0 +1,58 @@
+
+#include "dnsdist.hh"
+
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP)
+{
+  
+  DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, ids.qname.wirelength(), &ids.origDest, &ids.origRemote, dh, bufferSize, responseLen, isTCP, &ids.sentTime.d_start);
+  dr.origFlags = ids.origFlags;
+  dr.ecsAdded = ids.ecsAdded;
+  dr.ednsAdded = ids.ednsAdded;
+  dr.useZeroScope = ids.useZeroScope;
+  dr.packetCache = std::move(ids.packetCache);
+  dr.delayMsec = ids.delayMsec;
+  dr.skipCache = ids.skipCache;
+  dr.cacheKey = ids.cacheKey;
+  dr.cacheKeyNoECS = ids.cacheKeyNoECS;
+  dr.dnssecOK = ids.dnssecOK;
+  dr.tempFailureTTL = ids.tempFailureTTL;
+  dr.qTag = std::move(ids.qTag);
+  dr.subnet = std::move(ids.subnet);
+#ifdef HAVE_PROTOBUF
+  dr.uniqueId = std::move(ids.uniqueId);
+#endif
+  if (ids.dnsCryptQuery) {
+    dr.dnsCryptQuery = std::move(ids.dnsCryptQuery);
+  }
+
+  return dr;  
+}
+
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
+{
+  ids.origRemote = *dq.remote;
+  ids.origDest = *dq.local;
+  ids.sentTime.set(*dq.queryTime);
+  ids.qname = std::move(qname);
+  ids.qtype = dq.qtype;
+  ids.qclass = dq.qclass;
+  ids.delayMsec = dq.delayMsec;
+  ids.tempFailureTTL = dq.tempFailureTTL;
+  ids.origFlags = dq.origFlags;
+  ids.cacheKey = dq.cacheKey;
+  ids.cacheKeyNoECS = dq.cacheKeyNoECS;
+  ids.subnet = dq.subnet;
+  ids.skipCache = dq.skipCache;
+  ids.packetCache = dq.packetCache;
+  ids.ednsAdded = dq.ednsAdded;
+  ids.ecsAdded = dq.ecsAdded;
+  ids.useZeroScope = dq.useZeroScope;
+  ids.qTag = dq.qTag;
+  ids.dnssecOK = dq.dnssecOK;
+  
+  ids.dnsCryptQuery = std::move(dq.dnsCryptQuery);
+  
+#ifdef HAVE_PROTOBUF
+  ids.uniqueId = std::move(dq.uniqueId);
+#endif
+}
diff --git a/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.cc
new file mode 100644 (file)
index 0000000..d05882a
--- /dev/null
@@ -0,0 +1,91 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist.hh"
+#include "dnsdist-dynblocks.hh"
+
+uint64_t dnsdist_ffi_stat_node_get_queries_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.queries;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_noerrors_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.noerrors;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_nxdomains_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.nxdomains;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_servfails_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.servfails;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_drops_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->self.drops;
+}
+
+unsigned int dnsdist_ffi_stat_node_get_labels_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->node.labelsCount;
+}
+
+void dnsdist_ffi_stat_node_get_full_name_raw(const dnsdist_ffi_stat_node_t* node, const char** name, size_t* nameSize)
+{
+  const auto& storage = node->node.fullname;
+  *name = storage.c_str();
+  *nameSize = storage.size();
+}
+
+unsigned int dnsdist_ffi_stat_node_get_children_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->node.children.size();
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_queries_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.queries;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_noerrors_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.noerrors;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_nxdomains_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.nxdomains;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_servfails_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.servfails;
+}
+
+uint64_t dnsdist_ffi_stat_node_get_children_drops_count(const dnsdist_ffi_stat_node_t* node)
+{
+  return node->children.drops;
+}
diff --git a/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh b/pdns/dnsdistdist/dnsdist-lua-inspection-ffi.hh
new file mode 100644 (file)
index 0000000..c4329a2
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+extern "C" {
+  typedef struct dnsdist_ffi_stat_node_t dnsdist_ffi_stat_node_t;
+
+  uint64_t dnsdist_ffi_stat_node_get_queries_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_noerrors_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_nxdomains_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_servfails_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_drops_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  unsigned int dnsdist_ffi_stat_node_get_labels_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  void dnsdist_ffi_stat_node_get_full_name_raw(const dnsdist_ffi_stat_node_t* node, const char** name, size_t* nameSize) __attribute__ ((visibility ("default")));
+
+  unsigned int dnsdist_ffi_stat_node_get_children_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+
+  uint64_t dnsdist_ffi_stat_node_get_children_queries_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_noerrors_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_nxdomains_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_servfails_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+  uint64_t dnsdist_ffi_stat_node_get_children_drops_count(const dnsdist_ffi_stat_node_t* node) __attribute__ ((visibility ("default")));
+}
index 7361a4808c71c4b256e90289bfd98c31f5219de7..bbc4ff2056735d74af71f92d0c8396ad405f67a3 100644 (file)
@@ -863,30 +863,10 @@ public:
       return false;
     }
 
-    uint16_t optStart;
-    size_t optLen = 0;
-    bool last = false;
-    const char * packet = reinterpret_cast<const char*>(dq->dh);
-    std::string packetStr(packet, dq->len);
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return d_extrcode == 0;
-    }
-
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
-      return false;
-    }
-
-    if (optStart < dq->len && packet[optStart] != 0) {
-      // OPT RR Name != '.'
+    EDNS0Record edns0;
+    if (!getEDNS0Record(*dq, edns0)) {
       return false;
     }
-    EDNS0Record edns0;
-    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
-    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
-    memcpy(&edns0, packet + optStart + 5, sizeof edns0);
 
     return d_extrcode == edns0.extRCode;
   }
@@ -907,30 +887,10 @@ public:
   }
   bool matches(const DNSQuestion* dq) const override
   {
-    uint16_t optStart;
-    size_t optLen = 0;
-    bool last = false;
-    const char * packet = reinterpret_cast<const char*>(dq->dh);
-    std::string packetStr(packet, dq->len);
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return false;
-    }
-
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
-      return false;
-    }
-
-    if (optStart < dq->len && packetStr.at(optStart) != 0) {
-      // OPT RR Name != '.'
+    EDNS0Record edns0;
+    if (!getEDNS0Record(*dq, edns0)) {
       return false;
     }
-    EDNS0Record edns0;
-    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
-    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
-    memcpy(&edns0, packet + optStart + 5, sizeof edns0);
 
     return d_version < edns0.version;
   }
@@ -961,8 +921,7 @@ public:
       return false;
     }
 
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
+    if (optLen < optRecordMinimumSize) {
       return false;
     }
 
index 89be18efce60ba1054ec66b7b42667ba334656e8..0655f513e34433731e1f2aac9d854affc7cc12f7 100644 (file)
@@ -4,7 +4,7 @@ Lua actions in rules
 While we can pass every packet through the :func:`blockFilter` functions, it is also possible to configure :program:`dnsdist` to only hand off some packets for Lua inspection. 
 If you think Lua is too slow for your query load, or if you are doing heavy processing in Lua, this may make sense.
 
-To select specific packets for Lua attention, use :func:`addLuaAction` or :func:`addLuaResponseAction`.
+To select specific packets for Lua attention, use :func:`addAction` with :func:`LuaAction`, or :func:`addResponseAction` with :func:`LuaResponseAction`.
 
 A sample configuration could look like this::
 
@@ -17,4 +17,4 @@ A sample configuration could look like this::
     end
   end
 
-  addLuaAction(AllRule(), luarule)
+  addAction(AllRule(), LuaAction(luarule))
index 5b645405d61a9e43d89e83f15a7675bf0b27bba8..ffdbacf6667b55ff02b1b159e66f3e4eac85aab1 100644 (file)
@@ -63,4 +63,4 @@ A working example:
           end
   end
 
-  addLuaAction(AllRule(), pickPool)
+  addAction(AllRule(), LuaAction(pickPool))
index e23bd77c68572f8ef78d307d635fd8b8172ebecf..cf72ca7c51559bd2674a9e190075ba034d54d294 100644 (file)
@@ -15,6 +15,7 @@ First, a few words about :program:`dnsdist` architecture:
 The maximum number of threads in the TCP pool is controlled by the :func:`setMaxTCPClientThreads` directive, and defaults to 10.
 This number can be increased to handle a large number of simultaneous TCP connections.
 If all the TCP threads are busy, new TCP connections are queued while they wait to be picked up.
+Before 1.4.0, a TCP thread could only handle a single incoming connection at a time. Starting with 1.4.0 the handling of TCP connections is now event-based, so a single TCP worker can handle a large number of TCP incoming connections simultaneously.
 
 The maximum number of queued connections can be configured with :func:`setMaxTCPQueuedConnections` and defaults to 1000.
 Any value larger than 0 will cause new connections to be dropped if there are already too many queued.
@@ -37,12 +38,12 @@ When Lua inspection is needed, the best course of action is to restrict the quer
 
 :program:`dnsdist` design choices mean that the processing of UDP queries is done by only one thread per local bind.
 This is great to keep lock contention to a low level, but might not be optimal for setups using a lot of processing power, caused for example by a large number of complicated rules.
-To be able to use more CPU cores for UDP queries processing, it is possible to use the ``reuseport`` parameter of the :func:`addLocal` and :func:`setLocal` directives to be able to add several identical local binds to dnsdist::
+To be able to use more CPU cores for UDP queries processing, it is possible to use the ``reusePort`` parameter of the :func:`addLocal` and :func:`setLocal` directives to be able to add several identical local binds to dnsdist::
 
-  addLocal("192.0.2.1:53", {reuseport=true})
-  addLocal("192.0.2.1:53", {reuseport=true})
-  addLocal("192.0.2.1:53", {reuseport=true})
-  addLocal("192.0.2.1:53", {reuseport=true})
+  addLocal("192.0.2.1:53", {reusePort=true})
+  addLocal("192.0.2.1:53", {reusePort=true})
+  addLocal("192.0.2.1:53", {reusePort=true})
+  addLocal("192.0.2.1:53", {reusePort=true})
 
 :program:`dnsdist` will then add four identical local binds as if they were different IPs or ports, start four threads to handle incoming queries and let the kernel load balance those randomly to the threads, thus using four CPU cores for rules processing.
 Note that this require ``SO_REUSEPORT`` support in the underlying operating system (added for example in Linux 3.9).
index b32b427795e512d2e686c98cf2bfa1ea60cabad5..7ecdfa6545f0055a71a299a6fcd0824e063e8f1f 100644 (file)
@@ -22,6 +22,18 @@ ComboAddresses can be IPv4 or IPv6, and unless you want to know, you don't need
 
     Returns the port number.
 
+  .. method:: ComboAddress:ipdecrypt(key) -> ComboAddress
+
+    Decrypt this IP address as described in https://powerdns.org/ipcipher
+
+    :param string key: A 16 byte key. Note that this can be derived from a passphrase with the standalone function `makeIPCipherKey`
+
+  .. method:: ComboAddress:ipencrypt(key) -> ComboAddress
+
+    Encrypt this IP address as described in https://powerdns.org/ipcipher
+
+    :param string key: A 16 byte key. Note that this can be derived from a passphrase with the standalone function `makeIPCipherKey`
+
   .. method:: ComboAddress:isIPv4() -> bool
 
     Returns true if the address is an IPv4, false otherwise
index 62604d792db5f1450979f662579baa0138986189..ef1f68b4b1ccb89766b216ff96a2dab34975bc44 100644 (file)
@@ -42,6 +42,14 @@ Global configuration
 
   :param str path: The directory to load configuration files from. Each file must end in ``.conf``.
 
+.. function:: setSyslogFacility(facility)
+
+  .. versionadded:: 1.4.0
+
+  Set the syslog logging facility to ``facility``.
+
+  :param int facility: The new facility as a numeric value. Defaults to LOG_DAEMON.
+
 Listen Sockets
 ~~~~~~~~~~~~~~
 
index 6ee99ef8e5480cdc2e12485045f2858fffb72b5a..248f67148ec1160cf79d618b4b3a7a0e64e63ed9 100644 (file)
@@ -16,15 +16,15 @@ OPCode
 
 Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5
 
-.. _DNSQClass:
+.. _DNSClass:
 
-QClass
+DNSClass
 ------
 
-- ``QClass.IN``
-- ``QClass.CHAOS``
-- ``QClass.NONE``
-- ``QClass.ANY``
+- ``DNSClass.IN``
+- ``DNSClass.CHAOS``
+- ``DNSClass.NONE``
+- ``DNSClass.ANY``
 
 Reference: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-2
 
@@ -93,7 +93,7 @@ DNS Section
 DNSAction
 ---------
 
-These constants represent an Action that can be returned from the functions invoked by :func:`addLuaAction`.
+These constants represent an Action that can be returned from :func:`LuaAction` functions.
 
  * ``DNSAction.Allow``: let the query pass, skipping other rules
  * ``DNSAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
@@ -115,7 +115,7 @@ These constants represent an Action that can be returned from the functions invo
 DNSResponseAction
 -----------------
 
-These constants represent an Action that can be returned from the functions invoked by :func:`addLuaResponseAction`.
+These constants represent an Action that can be returned from :func:`LuaResponseAction` functions.
 
  * ``DNSResponseAction.Allow``: let the response pass, skipping other rules
  * ``DNSResponseAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
index c0fc2702d32593d99b83851728d818a23f2671c1..117ac18bb0b49eb806f111fc47270c33cf1b1f5f 100644 (file)
@@ -231,7 +231,7 @@ EDNSOptionView object
 
   An object that represents the values of a single EDNS option received in a query.
 
-  .. attribute:: EDNSOptionView.count -> int
+  .. method:: EDNSOptionView:count()
 
     The number of values for this EDNS option.
 
index 54d56823627cfeeaabaf7c629fe5578b3cffe786..0f63a5145446e0000b1a4a8248a5c6bc87617566 100644 (file)
@@ -3,7 +3,7 @@ Tuning related functions
 
 .. function:: setMaxTCPClientThreads(num)
 
-  Set the maximum of TCP client threads, handling TCP connections
+  Set the maximum of TCP client threads, handling TCP connections. Before 1.4.0 a TCP thread could only handle a single incoming TCP connection at a time, while after 1.4.0 it can handle a larger number of them simultaneously.
 
   :param int num:
 
index 36d8b3a55b5702f60250d3f157c1136dcd9e2c42..1ed837dea24f83cb38cdaf85863c0d9b8aff6eb5 100644 (file)
@@ -151,6 +151,9 @@ Rule Generators
   .. versionchanged:: 1.3.0
     The second argument returned by the ``function`` can be omitted. For earlier releases, simply return an empty string.
 
+  .. deprecated:: 1.4.0
+    Removed in 1.4.0, use :func:`LuaAction` with :func:`addAction` instead.
+
   Invoke a Lua function that accepts a :class:`DNSQuestion`.
   This function works similar to using :func:`LuaAction`.
   The ``function`` should return both a :ref:`DNSAction` and its argument `rule`. The `rule` is used as an argument
@@ -187,6 +190,9 @@ Rule Generators
   .. versionchanged:: 1.3.0
     The second argument returned by the ``function`` can be omitted. For earlier releases, simply return an empty string.
 
+  .. deprecated:: 1.4.0
+    Removed in 1.4.0, use :func:`LuaResponseAction` with :func:`addResponseAction` instead.
+
   Invoke a Lua function that accepts a :class:`DNSResponse`.
   This function works similar to using :func:`LuaResponseAction`.
   The ``function`` should return both a :ref:`DNSResponseAction` and its argument `rule`. The `rule` is used as an argument
@@ -963,6 +969,9 @@ The following actions exist.
   .. versionchanged:: 1.3.0
     ``options`` optional parameter added.
 
+  .. versionchanged:: 1.3.4
+    ``ipEncryptKey`` optional key added to the options table.
+
   Send the content of this query to a remote logger via Protocol Buffer.
   ``alterFunction`` is a callback, receiving a :class:`DNSQuestion` and a :class:`DNSDistProtoBufMessage`, that can be used to modify the Protocol Buffer content, for example for anonymization purposes
 
@@ -973,12 +982,16 @@ The following actions exist.
   Options:
 
   * ``serverID=""``: str - Set the Server Identity field.
+  * ``ipEncryptKey=""``: str - A key, that can be generated via the :ref:`makeIPCipherKey` function, to encrypt the IP address of the requestor for anonymization purposes. The encryption is done using ipcrypt for IPv4 and a 128-bit AES ECB operation for IPv6.
 
 .. function:: RemoteLogResponseAction(remoteLogger[, alterFunction[, includeCNAME [, options]]])
 
   .. versionchanged:: 1.3.0
     ``options`` optional parameter added.
 
+  .. versionchanged:: 1.3.4
+    ``ipEncryptKey`` optional key added to the options table.
+
   Send the content of this response to a remote logger via Protocol Buffer.
   ``alterFunction`` is the same callback that receiving a :class:`DNSQuestion` and a :class:`DNSDistProtoBufMessage`, that can be used to modify the Protocol Buffer content, for example for anonymization purposes
   ``includeCNAME`` indicates whether CNAME records inside the response should be parsed and exported.
@@ -992,6 +1005,7 @@ The following actions exist.
   Options:
 
   * ``serverID=""``: str - Set the Server Identity field.
+  * ``ipEncryptKey=""``: str - A key, that can be generated via the :ref:`makeIPCipherKey` function, to encrypt the IP address of the requestor for anonymization purposes. The encryption is done using ipcrypt for IPv4 and a 128-bit AES ECB operation for IPv6.
 
 .. function:: SetECSAction(v4 [, v6])
 
index a56e9e0fa8d83e663c0dfb6009afa4b92be9d9d9..10f7de4ff1ba49c53c23e7887f9bf64228a9cdc5 100644 (file)
@@ -67,6 +67,18 @@ fd-usage
 --------
 Number of currently used file descriptors.
 
+frontend-noerror
+----------------
+Number of NoError answers sent to clients.
+
+frontend-nxdomain
+-----------------
+Number of NXDomain answers sent to clients.
+
+frontend-servfail
+-----------------
+Number of ServFail answers sent to clients.
+
 latency-avg100
 --------------
 Average response latency in microseconds of the last 100 packets
index 920df10922323ce03e1084b5cca271bb09ca6954..756c08249b587346d9bad99fd3ff66dec11ce82c 100644 (file)
@@ -1,6 +1,11 @@
 Upgrade Guide
 =============
 
+1.3.x to 1.4.0
+--------------
+
+:func:`addLuaAction` and :func:`addLuaResponseAction` have been removed. Instead, use :func:`addAction` with a :func:`LuaAction`, or :func:`addResponseAction` with a :func:`LuaResponseAction`.
+
 1.3.2 to 1.3.3
 --------------
 
diff --git a/pdns/dnsdistdist/ext/ipcrypt/LICENSE b/pdns/dnsdistdist/ext/ipcrypt/LICENSE
new file mode 120000 (symlink)
index 0000000..2e88816
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/LICENSE
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/Makefile.am b/pdns/dnsdistdist/ext/ipcrypt/Makefile.am
new file mode 120000 (symlink)
index 0000000..8111f1d
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/Makefile.am
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c b/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.c
new file mode 120000 (symlink)
index 0000000..e7d1241
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/ipcrypt.c
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h b/pdns/dnsdistdist/ext/ipcrypt/ipcrypt.h
new file mode 120000 (symlink)
index 0000000..ad5c2b8
--- /dev/null
@@ -0,0 +1 @@
+../../../../ext/ipcrypt/ipcrypt.h
\ No newline at end of file
index 641a2797b4a72d8169be528b9973bf09475ca0e7..49304f125c00bb40a93ee8bbf00e9a8b73f2502b 100644 (file)
@@ -194,7 +194,8 @@ $(document).ready(function() {
                      var bouw='<table width="100%"><tr align=right><th>#</th><th align=left>Name</th><th align=left>Address</th><th>Status</th><th>Latency</th><th>Queries</th><th>Drops</th><th>QPS</th><th>Out</th><th>Weight</th><th>Order</th><th align=left>Pools</th></tr>';
                      $.each(data["servers"], function(a,b) {
                          bouw = bouw + ("<tr align=right><td>"+b["id"]+"</td><td align=left>"+b["name"]+"</td><td align=left>"+b["address"]+"</td><td>"+b["state"]+"</td>");
-                         bouw = bouw + ("<td>"+(b["latency"]).toFixed(2)+"</td><td>"+b["queries"]+"</td><td>"+b["reuseds"]+"</td><td>"+(b["qps"]).toFixed(2)+"</td><td>"+b["outstanding"]+"</td>");
+                         var latency = (b["latency"] === null) ? 0.0 : b["latency"];
+                         bouw = bouw + ("<td>"+latency.toFixed(2)+"</td><td>"+b["queries"]+"</td><td>"+b["reuseds"]+"</td><td>"+(b["qps"]).toFixed(2)+"</td><td>"+b["outstanding"]+"</td>");
                          bouw = bouw + ("<td>"+b["weight"]+"</td><td>"+b["order"]+"</td><td align=left>"+b["pools"]+"</td></tr>");
                      }); 
                      bouw = bouw + "</table>";
diff --git a/pdns/dnsdistdist/ipcipher.cc b/pdns/dnsdistdist/ipcipher.cc
new file mode 120000 (symlink)
index 0000000..794b273
--- /dev/null
@@ -0,0 +1 @@
+../ipcipher.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/ipcipher.hh b/pdns/dnsdistdist/ipcipher.hh
new file mode 120000 (symlink)
index 0000000..e8d5917
--- /dev/null
@@ -0,0 +1 @@
+../ipcipher.hh
\ No newline at end of file
index 1f516f14f11db0c5361e8831834338edd5110b36..3ee1f46c35b83d42fbb53d4633b686aad85861c6 100644 (file)
@@ -232,7 +232,7 @@ private:
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
-  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free))
+  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -247,12 +247,62 @@ public:
     if (!SSL_set_fd(d_conn.get(), d_socket)) {
       throw std::runtime_error("Error assigning socket");
     }
+  }
+
+  IOState convertIORequestToIOState(int res) const
+  {
+    int error = SSL_get_error(d_conn.get(), res);
+    if (error == SSL_ERROR_WANT_READ) {
+      return IOState::NeedRead;
+    }
+    else if (error == SSL_ERROR_WANT_WRITE) {
+      return IOState::NeedWrite;
+    }
+    else if (error == SSL_ERROR_SYSCALL) {
+      throw std::runtime_error("Error while processing TLS connection:" + std::string(strerror(errno)));
+    }
+    else {
+      throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error));
+    }
+  }
+
+  void handleIORequest(int res, unsigned int timeout)
+  {
+    auto state = convertIORequestToIOState(res);
+    if (state == IOState::NeedRead) {
+      res = waitForData(d_socket, timeout);
+      if (res <= 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+    }
+    else if (state == IOState::NeedWrite) {
+      res = waitForRWData(d_socket, false, timeout, 0);
+      if (res <= 0) {
+        throw std::runtime_error("Error waiting to write to TLS connection");
+      }
+    }
+  }
+
+  IOState tryHandshake() override
+  {
+    int res = SSL_accept(d_conn.get());
+    if (res == 1) {
+      return IOState::Done;
+    }
+    else if (res < 0) {
+      return convertIORequestToIOState(res);
+    }
 
+    throw std::runtime_error("Error accepting TLS connection");
+  }
+
+  void doHandshake() override
+  {
     int res = 0;
     do {
       res = SSL_accept(d_conn.get());
       if (res < 0) {
-        handleIORequest(res, timeout);
+        handleIORequest(res, d_timeout);
       }
     }
     while (res < 0);
@@ -262,24 +312,40 @@ public:
     }
   }
 
-  void handleIORequest(int res, unsigned int timeout)
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
   {
-    int error = SSL_get_error(d_conn.get(), res);
-    if (error == SSL_ERROR_WANT_READ) {
-      res = waitForData(d_socket, timeout);
-      if (res <= 0) {
-        throw std::runtime_error("Error reading from TLS connection");
+    do {
+      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
       }
-    }
-    else if (error == SSL_ERROR_WANT_WRITE) {
-      res = waitForRWData(d_socket, false, timeout, 0);
-      if (res <= 0) {
-        throw std::runtime_error("Error waiting to write to TLS connection");
+      else if (res < 0) {
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
       }
     }
-    else {
-      throw std::runtime_error("Error writing to TLS connection");
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res < 0) {
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
+      }
     }
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -300,7 +366,7 @@ public:
         handleIORequest(res, readTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
 
       if (totalTimeout) {
@@ -330,7 +396,7 @@ public:
         handleIORequest(res, writeTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
     }
     while (got < bufferSize);
@@ -346,6 +412,7 @@ public:
 
 private:
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+  unsigned int d_timeout;
 };
 
 class OpenSSLTLSIOCtx: public TLSCtx
@@ -650,7 +717,7 @@ public:
 
   GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
   {
-    unsigned int sslOptions = GNUTLS_SERVER;
+    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
     sslOptions |= GNUTLS_NO_SIGNAL;
 #endif
@@ -685,12 +752,86 @@ public:
     /* timeouts are in milliseconds */
     gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
+  }
 
+  void doHandshake() override
+  {
     int ret = 0;
     do {
       ret = gnutls_handshake(d_conn.get());
+      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    }
+    while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+  }
+
+  IOState tryHandshake() override
+  {
+    int ret = 0;
+
+    do {
+      ret = gnutls_handshake(d_conn.get());
+      if (ret == GNUTLS_E_SUCCESS) {
+        return IOState::Done;
+      }
+      else if (ret == GNUTLS_E_AGAIN) {
+        return IOState::NeedRead;
+      }
+      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    } while (ret == GNUTLS_E_INTERRUPTED);
+
+    throw std::runtime_error("Error accepting a new connection");
+  }
+
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
+  {
+    do {
+      ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
+      if (res == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
+      }
+      else if (res > 0) {
+        pos += static_cast<size_t>(res);
+      }
+      else if (res < 0) {
+        if (gnutls_error_is_fatal(res)) {
+          throw std::runtime_error("Error writing to TLS connection");
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          return IOState::NeedWrite;
+        }
+        warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+      }
     }
-    while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res > 0) {
+        pos += static_cast<size_t>(res);
+      }
+      else if (res < 0) {
+        if (gnutls_error_is_fatal(res)) {
+          throw std::runtime_error("Error reading from TLS connection");
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          return IOState::NeedRead;
+        }
+        warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
+      }
+    }
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -708,13 +849,21 @@ public:
         throw std::runtime_error("Error reading from TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
-          throw std::runtime_error("Error reading from TLS connection");
+          throw std::runtime_error("Error reading from TLS connection:" + std::string(gnutls_strerror(res)));
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          int result = waitForData(d_socket, readTimeout);
+          if (result <= 0) {
+            throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
+          }
+        }
+        else {
+          vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
         }
-        warnlog("Warning, non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
       }
 
       if (totalTimeout) {
@@ -742,13 +891,21 @@ public:
         throw std::runtime_error("Error writing to TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
-          throw std::runtime_error("Error writing to TLS connection");
+          throw std::runtime_error("Error writing to TLS connection: " + std::string(gnutls_strerror(res)));
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          int result = waitForRWData(d_socket, false, writeTimeout, 0);
+          if (result <= 0) {
+            throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
+          }
+        }
+        else {
+          vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
         }
-        warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
       }
     }
     while (got < bufferSize);
index 560c222b2958fb77203f4885a3f950bd95640bce..b191a268a5e094b35867ca456727f16ab374cba7 100644 (file)
@@ -10,6 +10,7 @@
 
 Rings g_rings;
 GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
 
 BOOST_AUTO_TEST_SUITE(dnsdistdynblocks_hh)
 
index a402864068bb8e7b6e8b698b950cc5ad66b67db0..357c902db74837e5f9313f997f965fd2ed6ab425 100644 (file)
@@ -22,7 +22,8 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   uint16_t qclass = QClass::IN;
   ComboAddress lc("127.0.0.1:53");
   ComboAddress rem("192.0.2.1:42");
-  struct dnsheader* dh = nullptr;
+  struct dnsheader dh;
+  memset(&dh, 0, sizeof(dh));
   size_t bufferSize = 0;
   size_t queryLen = 0;
   bool isTcp = false;
@@ -32,7 +33,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   /* the internal QPS limiter does not use the real time */
   gettime(&expiredTime);
 
-  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, bufferSize, queryLen, isTcp, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime);
 
   for (size_t idx = 0; idx < maxQPS; idx++) {
     /* let's use different source ports, it shouldn't matter */
@@ -45,12 +46,16 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   BOOST_CHECK_EQUAL(rule.matches(&dq), true);
   BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1);
 
+  /* remove all entries that have not been updated since 'now' + 1,
+     so all of them */
   expiredTime.tv_sec += 1;
   rule.cleanup(expiredTime);
 
   /* we should have been cleaned up */
   BOOST_CHECK_EQUAL(rule.getEntriesCount(), 0);
 
+  struct timespec beginInsertionTime;
+  gettime(&beginInsertionTime);
   /* we should not be blocked anymore */
   BOOST_CHECK_EQUAL(rule.matches(&dq), false);
   /* and we be back */
@@ -58,21 +63,21 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
 
 
   /* Let's insert a lot of different sources now */
-  struct timespec insertionTime;
-  gettime(&insertionTime);
   for (size_t idxByte3 = 0; idxByte3 < 256; idxByte3++) {
     for (size_t idxByte4 = 0; idxByte4 < 256; idxByte4++) {
       rem = ComboAddress("10.0." + std::to_string(idxByte3) + "." + std::to_string(idxByte4));
       BOOST_CHECK_EQUAL(rule.matches(&dq), false);
     }
   }
+  struct timespec endInsertionTime;
+  gettime(&endInsertionTime);
 
   /* don't forget the existing entry */
   size_t total = 1 + 256 * 256;
   BOOST_CHECK_EQUAL(rule.getEntriesCount(), total);
 
   /* make sure all entries are still valid */
-  struct timespec notExpiredTime = insertionTime;
+  struct timespec notExpiredTime = beginInsertionTime;
   notExpiredTime.tv_sec -= 1;
 
   size_t scanned = 0;
@@ -83,7 +88,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   BOOST_CHECK_EQUAL(rule.getEntriesCount(), total);
 
   /* make sure all entries are _not_ valid anymore */
-  expiredTime = insertionTime;
+  expiredTime = endInsertionTime;
   expiredTime.tv_sec += 1;
 
   removed = rule.cleanup(expiredTime, &scanned);
diff --git a/pdns/dnsdistdist/test-mplexer.cc b/pdns/dnsdistdist/test-mplexer.cc
new file mode 120000 (symlink)
index 0000000..f406267
--- /dev/null
@@ -0,0 +1 @@
+../test-mplexer.cc
\ No newline at end of file
index bd8e23dad5ecce0e5ae2a355f305249e54079acc..3ae3ea7ee3886043f7fdf6c37509ea630cd09bb7 100644 (file)
@@ -40,19 +40,40 @@ otherwise, obfuscate the response IP address
 #include "statbag.hh"
 #include "dnspcap.hh"
 #include "iputils.hh"
-
+#include "ipcipher.hh"
 #include "namespaces.hh"
+#include <boost/program_options.hpp>
+#include "base64.hh"
 
 StatBag S;
 
+namespace po = boost::program_options;
+po::variables_map g_vm;
+
+
 class IPObfuscator
 {
 public:
-  IPObfuscator() : d_romap(d_ipmap), d_ro6map(d_ip6map), d_counter(0)
+  virtual uint32_t obf4(uint32_t orig)=0;
+  virtual struct in6_addr obf6(const struct in6_addr& orig)=0;
+};
+
+class IPSeqObfuscator : public IPObfuscator
+{
+public:
+  IPSeqObfuscator() : d_romap(d_ipmap), d_ro6map(d_ip6map), d_counter(0)
   {
   }
 
-  uint32_t obf4(uint32_t orig)
+  ~IPSeqObfuscator()
+  {}
+
+  static std::unique_ptr<IPObfuscator> make()
+  {
+    return std::unique_ptr<IPObfuscator>(new IPSeqObfuscator());
+  }
+
+  uint32_t obf4(uint32_t orig) override
   {
     if(d_romap.count(orig))
       return d_ipmap[orig];
@@ -61,7 +82,7 @@ public:
     }
   }
 
-  struct in6_addr obf6(const struct in6_addr& orig)
+  struct in6_addr obf6(const struct in6_addr& orig) override
   {
     uint32_t val;
     if(d_ro6map.count(orig))
@@ -95,6 +116,48 @@ private:
   uint32_t d_counter;
 };
 
+class IPCipherObfuscator : public IPObfuscator
+{
+public:
+  IPCipherObfuscator(const std::string& key, bool decrypt)  : d_key(key), d_decrypt(decrypt)
+  {
+    if(d_key.size()!=16) {
+      throw std::runtime_error("IPCipher requires a 128 bit key");
+    }
+  }
+
+  ~IPCipherObfuscator()
+  {}
+  static std::unique_ptr<IPObfuscator> make(std::string key, bool decrypt)
+  {
+    return std::unique_ptr<IPObfuscator>(new IPCipherObfuscator(key, decrypt));
+  }
+
+  uint32_t obf4(uint32_t orig) override
+  {
+    ComboAddress ca;
+    ca.sin4.sin_family = AF_INET;
+    ca.sin4.sin_addr.s_addr = orig;
+    ca = d_decrypt ? decryptCA(ca, d_key) : encryptCA(ca, d_key);
+    return ca.sin4.sin_addr.s_addr;
+
+  }
+
+  struct in6_addr obf6(const struct in6_addr& orig) override
+  {
+    ComboAddress ca;
+    ca.sin4.sin_family = AF_INET6;
+    ca.sin6.sin6_addr = orig;
+    ca = d_decrypt ? decryptCA(ca, d_key) : encryptCA(ca, d_key);
+    return ca.sin6.sin6_addr;
+  }
+
+private:
+  std::string d_key;
+  bool d_decrypt;
+};
+
+
 void usage() {
   cerr<<"Syntax: dnswasher INFILE1 [INFILE2..] OUTFILE"<<endl;
 }
@@ -102,53 +165,98 @@ void usage() {
 int main(int argc, char** argv)
 try
 {
-  for (int i = 1; i < argc; i++) {
-    if ((string) argv[i] == "--help") {
-      usage();
-      exit(EXIT_SUCCESS);
-    }
+  po::options_description desc("Allowed options");
+  desc.add_options()
+    ("help,h", "produce help message")
+    ("version", "show version number")
+    ("key,k", po::value<string>(), "base64 encoded 128 bit key for ipcipher")
+    ("passphrase,p", po::value<string>(), "passphrase for ipcipher (will be used to derive key)")
+    ("decrypt,d", "decrypt IP addresses with ipcipher");
 
-    if ((string) argv[i] == "--version") {
-      cerr<<"dnswasher "<<VERSION<<endl;
-      exit(EXIT_SUCCESS);
-    }
+  po::options_description alloptions;
+  po::options_description hidden("hidden options");
+  hidden.add_options()
+    ("infiles", po::value<vector<string>>(), "PCAP source file(s)")
+    ("outfile", po::value<string>(), "outfile");
+
+
+  alloptions.add(desc).add(hidden);
+  po::positional_options_description p;
+  p.add("infiles", 1);
+  p.add("outfile", 1);
+
+  po::store(po::command_line_parser(argc, argv).options(alloptions).positional(p).run(), g_vm);
+  po::notify(g_vm);
+
+  if(g_vm.count("help")) {
+    usage();
+    cout<<desc<<endl;
+    exit(EXIT_SUCCESS);
+  }
+
+  if(g_vm.count("version")) {
+    cout<<"dnswasher "<<VERSION<<endl;
+    exit(EXIT_SUCCESS);
   }
 
-  if(argc < 3) {
+  if(!g_vm.count("outfile")) {
+    cout<<"Missing outfile"<<endl;
     usage();
-    exit(1);
+    exit(EXIT_FAILURE);
   }
 
-  PcapPacketWriter pw(argv[argc-1]);
-  IPObfuscator ipo;
-  // 0          1   2   3    - argc == 4
-  // dnswasher in1 in2 out
-  for(int n=1; n < argc -1; ++n) {
-    PcapPacketReader pr(argv[n]);
+  bool doDecrypt = g_vm.count("decrypt");
+
+  PcapPacketWriter pw(g_vm["outfile"].as<string>());
+  std::unique_ptr<IPObfuscator> ipo;
+
+  if(!g_vm.count("key") && !g_vm.count("passphrase"))
+    ipo = IPSeqObfuscator::make();
+  else if(g_vm.count("key") && !g_vm.count("passphrase")) {
+    string key;
+    if(B64Decode(g_vm["key"].as<string>(), key) < 0) {
+      cerr<<"Invalidly encoded base64 key provided"<<endl;
+      exit(EXIT_FAILURE);
+    }
+    ipo = IPCipherObfuscator::make(key, doDecrypt);
+  }
+  else if(!g_vm.count("key") && g_vm.count("passphrase")) {
+    string key = makeIPCipherKey(g_vm["passphrase"].as<string>());
+
+    ipo = IPCipherObfuscator::make(key, doDecrypt);
+  }
+  else {
+    cerr<<"Can't specify both 'key' and 'passphrase'"<<endl;
+    exit(EXIT_FAILURE);
+  }
+
+  for(const auto& inf : g_vm["infiles"].as<vector<string>>()) {
+    PcapPacketReader pr(inf);
     pw.setPPR(pr);
 
     while(pr.getUDPPacket()) {
       if(ntohs(pr.d_udp->uh_dport)==53 || (ntohs(pr.d_udp->uh_sport)==53 && pr.d_len > sizeof(dnsheader))) {
         dnsheader* dh=(dnsheader*)pr.d_payload;
-        
+
         if (pr.d_ip->ip_v == 4){
           uint32_t *src=(uint32_t*)&pr.d_ip->ip_src;
           uint32_t *dst=(uint32_t*)&pr.d_ip->ip_dst;
-          
+
           if(dh->qr)
-            *dst=htonl(ipo.obf4(*dst));
+            *dst=ipo->obf4(*dst);
           else
-            *src=htonl(ipo.obf4(*src));
-          
+            *src=ipo->obf4(*src);
+
           pr.d_ip->ip_sum=0;
         } else if (pr.d_ip->ip_v == 6) {
           auto src=&pr.d_ip6->ip6_src;
           auto dst=&pr.d_ip6->ip6_dst;
-          
+
           if(dh->qr)
-            *dst=ipo.obf6(*dst);
+            *dst=ipo->obf6(*dst);
           else
-            *src=ipo.obf6(*src);
+            *src=ipo->obf6(*src);
+          // IPv6 checksum does not cover source/destination addresses
         }
         pw.write();
       }
index 0d5bffe0d93040f5be3cf539bfcefe592f40a984..cce1b74769347ac2d40cfb7d73dec83e69ba29a4 100644 (file)
@@ -70,6 +70,13 @@ void dolog(std::ostream& os, const char* s, T value, Args... args)
 extern bool g_verbose;
 extern bool g_syslog;
 
+inline void setSyslogFacility(int facility)
+{
+  /* we always call openlog() right away at startup */
+  closelog();
+  openlog("dnsdist", LOG_PID|LOG_NDELAY, facility);
+}
+
 template<typename... Args>
 void genlog(int level, const char* s, Args... args)
 {
index 97d7e82ff6a887a0799e29e6a5873fa4614bead7..433687d21456f8c24e7227683bb39bcdf719a04d 100644 (file)
@@ -45,7 +45,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -94,9 +94,9 @@ EpollFDMultiplexer::EpollFDMultiplexer() : d_eevents(new epoll_event[s_maxevents
     
 }
 
-void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+void EpollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct epoll_event eevent;
   
@@ -156,13 +156,13 @@ int EpollFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(d_eevents[n].data.fd);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't refind ourselves as writable!
     }
     d_iter=d_writeCallbacks.find(d_eevents[n].data.fd);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
   d_inrun=false;
diff --git a/pdns/ipcipher.cc b/pdns/ipcipher.cc
new file mode 100644 (file)
index 0000000..57a2aa3
--- /dev/null
@@ -0,0 +1,99 @@
+#include "ipcipher.hh"
+#include "ext/ipcrypt/ipcrypt.h"
+#include <openssl/aes.h>
+#include <openssl/evp.h>
+
+/*
+int PKCS5_PBKDF2_HMAC_SHA1(const char *pass, int passlen,
+                           const unsigned char *salt, int saltlen, int iter,
+                           int keylen, unsigned char *out);
+*/
+std::string makeIPCipherKey(const std::string& password)
+{
+  static const char salt[]="ipcipheripcipher";
+  unsigned char out[16];
+
+  PKCS5_PBKDF2_HMAC_SHA1(password.c_str(), password.size(), (const unsigned char*)salt, sizeof(salt)-1, 50000, sizeof(out), out);
+
+  return std::string((const char*)out, (const char*)out + sizeof(out));
+}
+
+static ComboAddress encryptCA4(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  // always returns 0, has no failure mode
+  ipcrypt_encrypt(      (unsigned char*)&ret.sin4.sin_addr.s_addr,
+                 (const unsigned char*)  &ca.sin4.sin_addr.s_addr,
+                 (const unsigned char*)key.c_str());
+  return ret;
+}
+
+static ComboAddress decryptCA4(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  // always returns 0, has no failure mode
+  ipcrypt_decrypt(      (unsigned char*)&ret.sin4.sin_addr.s_addr,
+                 (const unsigned char*)  &ca.sin4.sin_addr.s_addr,
+                 (const unsigned char*)key.c_str());
+  return ret;
+}
+
+
+static ComboAddress encryptCA6(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+
+  AES_KEY wctx;
+  AES_set_encrypt_key((const unsigned char*)key.c_str(), 128, &wctx);
+  AES_encrypt((const unsigned char*)&ca.sin6.sin6_addr.s6_addr,
+              (unsigned char*)&ret.sin6.sin6_addr.s6_addr, &wctx);
+
+  return ret;
+}
+
+static ComboAddress decryptCA6(const ComboAddress& ca, const std::string &key)
+{
+  if(key.size() != 16)
+    throw std::runtime_error("Need 128 bits of key for ipcrypt");
+
+  ComboAddress ret=ca;
+  AES_KEY wctx;
+  AES_set_decrypt_key((const unsigned char*)key.c_str(), 128, &wctx);
+  AES_decrypt((const unsigned char*)&ca.sin6.sin6_addr.s6_addr,
+              (unsigned char*)&ret.sin6.sin6_addr.s6_addr, &wctx);
+
+  return ret;
+}
+
+
+ComboAddress encryptCA(const ComboAddress& ca, const std::string& key)
+{
+  if(ca.sin4.sin_family == AF_INET)
+    return encryptCA4(ca, key);
+  else if(ca.sin4.sin_family == AF_INET6)
+    return encryptCA6(ca, key);
+  else
+    throw std::runtime_error("ipcrypt can't encrypt non-IP addresses");
+}
+
+ComboAddress decryptCA(const ComboAddress& ca, const std::string& key)
+{
+  if(ca.sin4.sin_family == AF_INET)
+    return decryptCA4(ca, key);
+  else if(ca.sin4.sin_family == AF_INET6)
+    return decryptCA6(ca, key);
+  else
+    throw std::runtime_error("ipcrypt can't decrypt non-IP addresses");
+
+}
diff --git a/pdns/ipcipher.hh b/pdns/ipcipher.hh
new file mode 100644 (file)
index 0000000..cbb932d
--- /dev/null
@@ -0,0 +1,9 @@
+#pragma once
+#include "iputils.hh"
+#include <string>
+
+// see https://powerdns.org/ipcipher
+
+ComboAddress encryptCA(const ComboAddress& ca, const std::string& key);
+ComboAddress decryptCA(const ComboAddress& ca, const std::string& key);
+std::string makeIPCipherKey(const std::string& password);
index daf5997d6f465acafe15de29f361caa02ee31a97..84425e66cb08514ecbacfdb92e716f266b77d91f 100644 (file)
@@ -47,7 +47,7 @@ int SSocket(int family, int type, int flags)
 
 int SConnect(int sockfd, const ComboAddress& remote)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     RuntimeError(boost::format("connecting socket to %s: %s") % remote.toStringWithPort() % strerror(savederrno));
@@ -57,12 +57,12 @@ int SConnect(int sockfd, const ComboAddress& remote)
 
 int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
 {
-  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  int ret = connect(sockfd, reinterpret_cast<const struct sockaddr*>(&remote), remote.getSocklen());
   if(ret < 0) {
     int savederrno = errno;
     if (savederrno == EINPROGRESS) {
       if (timeout <= 0) {
-        return ret;
+        return savederrno;
       }
 
       /* we wait until the connection has been established */
@@ -97,7 +97,7 @@ int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout)
     }
   }
 
-  return ret;
+  return 0;
 }
 
 int SBind(int sockfd, const ComboAddress& local)
@@ -269,40 +269,108 @@ void ComboAddress::truncate(unsigned int bits) noexcept
   *place &= (~((1<<bitsleft)-1));
 }
 
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf)
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags)
 {
+  int remainingTime = totalTimeout;
+  time_t start = 0;
+  if (totalTimeout) {
+    start = time(nullptr);
+  }
+
   struct msghdr msgh;
   struct iovec iov;
   char cbuf[256];
+
+  /* Set up iov and msgh structures. */
+  memset(&msgh, 0, sizeof(struct msghdr));
+  msgh.msg_control = nullptr;
+  msgh.msg_controllen = 0;
+  if (dest) {
+    msgh.msg_name = reinterpret_cast<void*>(const_cast<ComboAddress*>(dest));
+    msgh.msg_namelen = dest->getSocklen();
+  }
+  else {
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+  }
+
+  msgh.msg_flags = 0;
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  iov.iov_base = reinterpret_cast<void*>(const_cast<char*>(buffer));
+  iov.iov_len = len;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  msgh.msg_flags = 0;
+
+  size_t sent = 0;
   bool firstTry = true;
-  fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast<char*>(buffer), len, &dest);
-  addCMsgSrcAddr(&msgh, cbuf, &local, localItf);
 
   do {
-    ssize_t written = sendmsg(fd, &msgh, 0);
 
-    if (written > 0)
-      return written;
+#ifdef MSG_FASTOPEN
+    if (flags & MSG_FASTOPEN && firstTry == false) {
+      flags &= ~MSG_FASTOPEN;
+    }
+#endif /* MSG_FASTOPEN */
+
+    ssize_t res = sendmsg(fd, &msgh, flags);
 
-    if (errno == EAGAIN) {
-      if (firstTry) {
-        int res = waitForRWData(fd, false, timeout, 0);
-        if (res > 0) {
-          /* there is room available */
-          firstTry = false;
+    if (res > 0) {
+      size_t written = static_cast<size_t>(res);
+      sent += written;
+
+      if (sent == len) {
+        return sent;
+      }
+
+      /* partial write */
+      iov.iov_len -= written;
+      iov.iov_base = reinterpret_cast<void*>(reinterpret_cast<char*>(iov.iov_base) + written);
+      written = 0;
+    }
+    else if (res == -1) {
+      if (errno == EINTR) {
+        continue;
+      }
+      else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS || errno == ENOTCONN) {
+        /* EINPROGRESS might happen with non blocking socket,
+           especially with TCP Fast Open */
+        if (totalTimeout <= 0 && idleTimeout <= 0) {
+          return sent;
+        }
+
+        if (firstTry) {
+          int res = waitForRWData(fd, false, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime, 0);
+          if (res > 0) {
+            /* there is room available */
+            firstTry = false;
+          }
+          else if (res == 0) {
+            throw runtime_error("Timeout while waiting to write data");
+          } else {
+            throw runtime_error("Error while waiting for room to write data");
+          }
         }
-        else if (res == 0) {
+        else {
           throw runtime_error("Timeout while waiting to write data");
-        } else {
-          throw runtime_error("Error while waiting for room to write data");
         }
       }
       else {
-        throw runtime_error("Timeout while waiting to write data");
+        unixDie("failed in sendMsgWithTimeout");
       }
     }
-    else {
-      unixDie("failed in write2WithTimeout");
+    if (totalTimeout) {
+      time_t now = time(nullptr);
+      int elapsed = now - start;
+      if (elapsed >= remainingTime) {
+        throw runtime_error("Timeout while sending data");
+      }
+      start = now;
+      remainingTime -= elapsed;
     }
   }
   while (firstTry);
index 490e45ac436665556dd77f7e545187cb810ca9a2..498e03c9243df8334ce40ccec24e0741320387f9 100644 (file)
@@ -1062,7 +1062,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat
 bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv);
 void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, char* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
 ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to);
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf);
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 bool sendSizeAndMsgWithTimeout(int sock, uint16_t bufferLen, const char* buffer, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 /* requires a non-blocking, connected TCP socket */
 bool isTCPSocketUsable(int sock);
index 01d2505d8ffd885481c24b1af15e1cdc37d26e54..485e720bcd8d2700be491a236f88feb1ed52b321 100644 (file)
 
 string doGetStats();
 
-IXFRDistWebServer::IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl) :
+IXFRDistWebServer::IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl, const string &loglevel) :
   d_ws(std::unique_ptr<WebServer>(new WebServer(listenAddress.toString(), listenAddress.getPort())))
 {
   d_ws->setACL(acl);
+  d_ws->setLogLevel(loglevel);
   d_ws->registerWebHandler("/metrics", boost::bind(&IXFRDistWebServer::getMetrics, this, _1, _2));
   d_ws->bind();
 }
index 580e976344332b3d0a577bf89c36ccc8541de766..e1c1734cc39b14e28d546cb8b2f7095a07091381 100644 (file)
@@ -26,7 +26,7 @@
 class IXFRDistWebServer
 {
   public:
-    explicit IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl);
+    explicit IXFRDistWebServer(const ComboAddress &listenAddress, const NetmaskGroup &acl, const string &loglevel);
     void go();
 
   private:
index f49e0fae56228e7b77dd6c547589e048680eca81..d23d554cbe245d1121880319a5cb95d7e709c59c 100644 (file)
@@ -343,7 +343,7 @@ void updateThread(const string& workdir, const uint16_t& keep, const uint16_t& a
 
       // TODO Keep track of 'down' masters
       set<ComboAddress>::const_iterator it(domainConfig.second.masters.begin());
-      std::advance(it, random() % domainConfig.second.masters.size());
+      std::advance(it, dns_random(domainConfig.second.masters.size()));
       ComboAddress master = *it;
 
       string dir = workdir + "/" + domain.toString();
@@ -1138,6 +1138,16 @@ static bool parseAndCheckConfig(const string& configpath, YAML::Node& config) {
     }
   }
 
+  if (config["webserver-loglevel"]) {
+    try {
+      config["webserver-loglevel"].as<string>();
+    }
+    catch (const runtime_error &e) {
+      g_log<<Logger::Error<<"Unable to read 'webserver-loglevel' value: "<<e.what()<<endl;
+      retval = false;
+    }
+  }
+
   return retval;
 }
 
@@ -1280,8 +1290,18 @@ int main(int argc, char** argv) {
       }
     }
 
+    string loglevel = "normal";
+    if (config["webserver-loglevel"]) {
+      loglevel = config["webserver-loglevel"].as<string>();
+    }
+
     // Launch the webserver!
-    std::thread(&IXFRDistWebServer::go, IXFRDistWebServer(config["webserver-address"].as<ComboAddress>(), wsACL)).detach();
+    try {
+      std::thread(&IXFRDistWebServer::go, IXFRDistWebServer(config["webserver-address"].as<ComboAddress>(), wsACL, loglevel)).detach();
+    } catch (const PDNSException &e) {
+      g_log<<Logger::Error<<"Unable to start webserver: "<<e.reason<<endl;
+      had_error = true;
+    }
   }
 
   int newuid = 0;
index b995cb5ef1c45aac8dc060c2cf83e9e57719ec15..976b7f23c6224762d428575ca972eefda90468f5 100644 (file)
@@ -88,6 +88,12 @@ webserver-acl:
   - 127.0.0.0/8
   - ::1/128
 
+# How much the webserver should log: 'none', 'normal' or 'detailed'
+# With 'none', nothing is logged except for errors
+# With 'normal' (the default), one line per request is logged in the style of the common log format
+# with 'detailed', the full requests and responses (including headers) are logged
+webserver-loglevel: normal
+
 # The domains to redistribute, the 'master' and 'domains' keys are mandatory.
 # When no port is specified, 53 is used. When specifying ports for IPv6, use the
 # "bracket" notation:
index 44d3f467354a84374b1a7ac4c20a33eb862f93e9..0c193d1dd44af73b8299799e103096e41d68652f 100644 (file)
@@ -47,7 +47,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
   string getName() const override
   {
@@ -80,9 +80,9 @@ KqueueFDMultiplexer::KqueueFDMultiplexer() : d_kevents(new struct kevent[s_maxev
     throw FDMultiplexerException("Setting up kqueue: "+stringerror());
 }
 
-void KqueueFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void KqueueFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   struct kevent kqevent;
   EV_SET(&kqevent, fd, (&cbmap == &d_readCallbacks) ? EVFILT_READ : EVFILT_WRITE, EV_ADD, 0,0,0);
@@ -144,14 +144,14 @@ int KqueueFDMultiplexer::run(struct timeval* now, int timeout)
   for(int n=0; n < ret; ++n) {
     d_iter=d_readCallbacks.find(d_kevents[n].ident);
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       continue; // so we don't find ourselves as writable again
     }
 
     d_iter=d_writeCallbacks.find(d_kevents[n].ident);
 
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
     }
   }
 
index cef57880230e12f3f3b338d3a76653ed582fdc69..3ebed5dcd35e2034a17a65342877c24254a11a00 100644 (file)
@@ -8,6 +8,7 @@
 #include "ueberbackend.hh"
 #include <boost/format.hpp>
 #include "dnsrecords.hh"
+#include "dns_random.hh"
 
 #include "../modules/geoipbackend/geoipinterface.hh" // only for the enum
 
@@ -254,7 +255,7 @@ static ComboAddress pickrandom(const vector<ComboAddress>& ips)
   if (ips.empty()) {
     throw std::invalid_argument("The IP list cannot be empty");
   }
-  return ips[random() % ips.size()];
+  return ips[dns_random(ips.size())];
 }
 
 static ComboAddress hashed(const ComboAddress& who, const vector<ComboAddress>& ips)
@@ -278,7 +279,7 @@ static ComboAddress pickwrandom(const vector<pair<int,ComboAddress> >& wips)
     sum += i.first;
     pick.push_back({sum, i.second});
   }
-  int r = random() % sum;
+  int r = dns_random(sum);
   auto p = upper_bound(pick.begin(), pick.end(),r, [](int r, const decltype(pick)::value_type& a) { return  r < a.first;});
   return p->second;
 }
@@ -384,7 +385,7 @@ static ComboAddress pickclosest(const ComboAddress& bestwho, const vector<ComboA
     //          cout<<"    distance: "<<sqrt(dist2) * 40000.0/360<<" km"<<endl; // length of a degree
     ranked[dist2].push_back(c);
   }
-  return ranked.begin()->second[random() % ranked.begin()->second.size()];
+  return ranked.begin()->second[dns_random(ranked.begin()->second.size())];
 }
 
 static std::vector<DNSZoneRecord> lookup(const DNSName& name, uint16_t qtype, int zoneid)
@@ -856,7 +857,7 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
         for(const auto& nmpair : netmasks) {
           Netmask nm(nmpair.second);
           if(nm.match(bestwho)) {
-            return destinations[random() % destinations.size()].second;
+            return destinations[dns_random(destinations.size())].second;
           }
         }
       }
index 80a727a376859b8c3eb3756ceea00e0089db7461..4ad23d22460b0688afdbd6283887225ec45fc657 100644 (file)
@@ -322,10 +322,10 @@ void RecursorLua4::postPrepareContext()
   d_lw->registerMember("size", &EDNSOptionViewValue::size);
   d_lw->registerFunction<std::string(EDNSOptionViewValue::*)()>("getContent", [](const EDNSOptionViewValue& value) { return std::string(value.content, value.size); });
   d_lw->registerFunction<size_t(EDNSOptionView::*)()>("count", [](const EDNSOptionView& option) { return option.values.size(); });
-  d_lw->registerFunction<std::vector<std::pair<int, string>>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
-      std::vector<std::pair<int, string> > values;
+  d_lw->registerFunction<std::vector<string>(EDNSOptionView::*)()>("getValues", [] (const EDNSOptionView& option) {
+      std::vector<string> values;
       for (const auto& value : option.values) {
-        values.push_back(std::make_pair(values.size(), std::string(value.content, value.size)));
+        values.push_back(std::string(value.content, value.size));
       }
       return values;
     });
index d70143d46be8d9d7e1aabc51aa558cb7b6b691e7..927651c332c5324e43b470715e29438d63922052 100644 (file)
 #include <boost/shared_array.hpp>
 #include <boost/tuple/tuple.hpp>
 #include <boost/tuple/tuple_comparison.hpp>
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/hashed_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
 #include <vector>
 #include <map>
 #include <stdexcept>
 #include <string>
 #include <sys/time.h>
 
+using namespace ::boost::multi_index;
+
 class FDMultiplexerException : public std::runtime_error
 {
 public:
@@ -51,14 +57,15 @@ class FDMultiplexer
 {
 public:
   typedef boost::any funcparam_t;
+  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
 protected:
 
-  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
   struct Callback
   {
     callbackfunc_t d_callback;
-    funcparam_t d_parameter;
+    mutable funcparam_t d_parameter;
     struct timeval d_ttd;
+    int d_fd;
   };
 
 public:
@@ -77,15 +84,15 @@ public:
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) = 0;
 
   //! Add an fd to the read watch list - currently an fd can only be on one list at a time!
-  virtual void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t())
+  virtual void addReadFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
   {
-    this->addFD(d_readCallbacks, fd, toDo, parameter);
+    this->addFD(d_readCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Add an fd to the write watch list - currently an fd can only be on one list at a time!
-  virtual void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t())
+  virtual void addWriteFD(int fd, callbackfunc_t toDo, const funcparam_t& parameter=funcparam_t(), const struct timeval* ttd=nullptr)
   {
-    this->addFD(d_writeCallbacks, fd, toDo, parameter);
+    this->addFD(d_writeCallbacks, fd, toDo, parameter, ttd);
   }
 
   //! Remove an fd from the read watch list. You can't call this function on an fd that is closed already!
@@ -104,25 +111,43 @@ public:
 
   virtual void setReadTTD(int fd, struct timeval tv, int timeout)
   {
-    if(!d_readCallbacks.count(fd))
+    const auto& it = d_readCallbacks.find(fd);
+    if (it == d_readCallbacks.end()) {
       throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
+    }
+
+    auto newEntry = *it;
     tv.tv_sec += timeout;
-    d_readCallbacks[fd].d_ttd=tv;
+    newEntry.d_ttd = tv;
+    d_readCallbacks.replace(it, newEntry);
   }
 
-  virtual funcparam_t& getReadParameter(int fd) 
+  virtual void setWriteTTD(int fd, struct timeval tv, int timeout)
   {
-    if(!d_readCallbacks.count(fd))
-      throw FDMultiplexerException("attempt to look up data in multiplexer for unlisted fd "+std::to_string(fd));
-    return d_readCallbacks[fd].d_parameter;
+    const auto& it = d_writeCallbacks.find(fd);
+    if (it == d_writeCallbacks.end()) {
+      throw FDMultiplexerException("attempt to timestamp fd not in the multiplexer");
+    }
+
+    auto newEntry = *it;
+    tv.tv_sec += timeout;
+    newEntry.d_ttd = tv;
+    d_writeCallbacks.replace(it, newEntry);
   }
 
-  virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv)
+  virtual std::vector<std::pair<int, funcparam_t> > getTimeouts(const struct timeval& tv, bool writes=false)
   {
     std::vector<std::pair<int, funcparam_t> > ret;
-    for(callbackmap_t::iterator i=d_readCallbacks.begin(); i!=d_readCallbacks.end(); ++i)
-      if(i->second.d_ttd.tv_sec && boost::tie(tv.tv_sec, tv.tv_usec) > boost::tie(i->second.d_ttd.tv_sec, i->second.d_ttd.tv_usec)) 
-        ret.push_back(std::make_pair(i->first, i->second.d_parameter));
+    const auto tied = boost::tie(tv.tv_sec, tv.tv_usec);
+    auto& idx = writes ? d_writeCallbacks.get<TTDOrderedTag>() : d_readCallbacks.get<TTDOrderedTag>();
+
+    for (auto it = idx.begin(); it != idx.end(); ++it) {
+      if (it->d_ttd.tv_sec == 0 || tied <= boost::tie(it->d_ttd.tv_sec, it->d_ttd.tv_usec)) {
+        break;
+      }
+      ret.push_back(std::make_pair(it->d_fd, it->d_parameter));
+    }
+
     return ret;
   }
 
@@ -137,31 +162,77 @@ public:
   
   virtual std::string getName() const = 0;
 
+  size_t getWatchedFDCount(bool writeFDs) const
+  {
+    return writeFDs ? d_writeCallbacks.size() : d_readCallbacks.size();
+  }
+
 protected:
-  typedef std::map<int, Callback> callbackmap_t;
+  struct FDBasedTag {};
+  struct TTDOrderedTag {};
+  struct ttd_compare
+  {
+    /* we want a 0 TTD (no timeout) to come _after_ everything else */
+    bool operator() (const struct timeval& lhs, const struct timeval& rhs) const
+    {
+      /* special treatment if at least one of the TTD is 0,
+         normal comparison otherwise */
+      if (lhs.tv_sec == 0 && rhs.tv_sec == 0) {
+        return false;
+      }
+      if (lhs.tv_sec == 0 && rhs.tv_sec != 0) {
+        return false;
+      }
+      if (lhs.tv_sec != 0 && rhs.tv_sec == 0) {
+        return true;
+      }
+
+      return std::tie(lhs.tv_sec, lhs.tv_usec) < std::tie(rhs.tv_sec, rhs.tv_usec);
+    }
+  };
+
+  typedef multi_index_container<
+    Callback,
+    indexed_by <
+                hashed_unique<tag<FDBasedTag>,
+                              member<Callback,int,&Callback::d_fd>
+                              >,
+                ordered_non_unique<tag<TTDOrderedTag>,
+                                   member<Callback,struct timeval,&Callback::d_ttd>,
+                                   ttd_compare
+                                   >
+               >
+  > callbackmap_t;
+
   callbackmap_t d_readCallbacks, d_writeCallbacks;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)=0;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)=0;
   virtual void removeFD(callbackmap_t& cbmap, int fd)=0;
   bool d_inrun;
   callbackmap_t::iterator d_iter;
 
-  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter)
+  void accountingAddFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr)
   {
     Callback cb;
+    cb.d_fd = fd;
     cb.d_callback=toDo;
     cb.d_parameter=parameter;
     memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
-  
-    if(cbmap.count(fd))
+    if (ttd) {
+      cb.d_ttd = *ttd;
+    }
+
+    auto pair = cbmap.insert(cb);
+    if (!pair.second) {
       throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
-    cbmap[fd]=cb;
+    }
   }
 
   void accountingRemoveFD(callbackmap_t& cbmap, int fd) 
   {
-    if(!cbmap.erase(fd)) 
+    if(!cbmap.erase(fd)) {
       throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
+    }
   }
 };
 
index 16027b61011f488f82a238e08b9635fa74727a20..2d01f12f272d8088eecfdadfdbf484ab72cd943b 100644 (file)
@@ -274,6 +274,7 @@ template<class Key, class Val>void MTasker<Key,Val>::makeThread(tfunc_t *start,
                                             &uc->uc_stack[uc->uc_stack.size()-1]);
 #endif /* PDNS_USE_VALGRIND */
 
+  ++d_threadsCount;
   auto& thread = d_threads[d_maxtid];
   auto mt = this;
   thread.start = [start, val, mt]() {
@@ -316,6 +317,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::schedule(struct timeval*  n
   }
   if(!d_zombiesQueue.empty()) {
     d_threads.erase(d_zombiesQueue.front());
+    --d_threadsCount;
     d_zombiesQueue.pop();
     return true;
   }
@@ -357,7 +359,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::schedule(struct timeval*  n
  */
 template<class Key, class Val>bool MTasker<Key,Val>::noProcesses() const
 {
-  return d_threads.empty();
+  return d_threadsCount == 0;
 }
 
 //! returns the number of processes running
@@ -366,7 +368,7 @@ template<class Key, class Val>bool MTasker<Key,Val>::noProcesses() const
  */
 template<class Key, class Val>unsigned int MTasker<Key,Val>::numProcesses() const
 {
-  return d_threads.size();
+  return d_threadsCount;
 }
 
 //! gives access to the list of Events threads are waiting for
index 87bc6723cff1dc760c513e880dd2b9f46499f027..0365e756a2690c25ba9141ba7960bc866f532872 100644 (file)
@@ -68,9 +68,10 @@ private:
 
   typedef std::map<int, ThreadInfo> mthreads_t;
   mthreads_t d_threads;
+  size_t d_stacksize;
+  size_t d_threadsCount;
   int d_tid;
   int d_maxtid;
-  size_t d_stacksize;
 
   EventVal d_waitval;
   enum waitstatusenum {Error=-1,TimeOut=0,Answer} d_waitstatus;
@@ -110,7 +111,7 @@ public:
       This limit applies solely to the stack, the heap is not limited in any way. If threads need to allocate a lot of data,
       the use of new/delete is suggested. 
    */
-  MTasker(size_t stacksize=16*8192) : d_tid(0), d_maxtid(0), d_stacksize(stacksize), d_waitstatus(Error)
+  MTasker(size_t stacksize=16*8192) : d_stacksize(stacksize), d_threadsCount(0), d_tid(0), d_maxtid(0), d_waitstatus(Error)
   {
     initMainStackBounds();
 
index 92cb57ae7bc7b4623570a4e10dcd82513f53360e..1ec0c11d64ee7baed789f5b747c5eca93f69fce6 100644 (file)
@@ -24,6 +24,7 @@
 #endif
 #include <bitset>
 #include "dnsparser.hh"
+#include "dns_random.hh"
 #include "iputils.hh"
 #include <boost/program_options.hpp>
 
@@ -59,6 +60,8 @@ int main(int argc, char** argv)
 try
 {
   set<ComboAddress> addrs;
+  ::arg().set("rng")="auto";
+  ::arg().set("entropy-source")="/dev/urandom";
 
   for(int n=1 ; n < argc; ++n) {
     if ((string) argv[n] == "--help") {
@@ -112,7 +115,7 @@ try
     }
     vector<uint8_t> outpacket;
     DNSPacketWriter pw(outpacket, DNSName(argv[2]), QType::SOA, 1, Opcode::Notify);
-    pw.getHeader()->id = random();
+    pw.getHeader()->id = dns_random(UINT16_MAX);
 
     if(send(sock, &outpacket[0], outpacket.size(), 0) < 0) {
       cerr<<"Unable to send notify to "<<addr.toStringWithPort()<<": "+stringerror()<<endl;
index 72888474af8ceaf1711e7fbd5214d1453730ebfb..e7189ece004c6515c115c929d42838ab004db1eb 100644 (file)
@@ -1003,7 +1003,7 @@ void PacketHandler::makeNOError(DNSPacket* p, DNSPacket* r, const DNSName& targe
   if(d_dk.isSecuredZone(sd.qname))
     addNSECX(p, r, target, wildcard, sd.qname, mode);
 
-  S.ringAccount("noerror-queries",p->qdomain.toLogString()+"/"+p->qtype.getName());
+  S.ringAccount("noerror-queries", p->qdomain, p->qtype);
 }
 
 
@@ -1561,7 +1561,7 @@ DNSPacket *PacketHandler::doQuestion(DNSPacket *p)
     r=p->replyPacket(); // generate an empty reply packet
     r->setRcode(RCode::ServFail);
     S.inc("servfail-packets");
-    S.ringAccount("servfail-queries",p->qdomain.toLogString());
+    S.ringAccount("servfail-queries", p->qdomain, p->qtype);
   }
   catch(PDNSException &e) {
     g_log<<Logger::Error<<"Backend reported permanent error which prevented lookup ("+e.reason+"), aborting"<<endl;
@@ -1573,7 +1573,7 @@ DNSPacket *PacketHandler::doQuestion(DNSPacket *p)
     r=p->replyPacket(); // generate an empty reply packet
     r->setRcode(RCode::ServFail);
     S.inc("servfail-packets");
-    S.ringAccount("servfail-queries",p->qdomain.toLogString());
+    S.ringAccount("servfail-queries", p->qdomain, p->qtype);
   }
   return r; 
 
index dcbbf88579b8b93cda311296be150be2b7e389fe..311d861f5416b53de9fc12c5b5057dda3ff0eb90 100644 (file)
@@ -161,6 +161,8 @@ struct RecThreadInfo
   deferredAdd_t deferredAdds;
   struct ThreadPipeSet pipes;
   std::thread thread;
+  MT_t* mt{nullptr};
+  uint64_t numberOfDistributedQueries{0};
   /* handle the web server, carbon, statistics and the control channel */
   bool isHandler{false};
   /* accept incoming queries (and distributes them to the workers if pdns-distributes-queries is set) */
@@ -226,6 +228,7 @@ static std::set<uint16_t> s_avoidUdpSourcePorts;
 #endif
 static uint16_t s_minUdpSourcePort;
 static uint16_t s_maxUdpSourcePort;
+static double s_balancingFactor;
 
 RecursorControlChannel s_rcc; // only active in the handler thread
 RecursorStats g_stats;
@@ -1626,8 +1629,10 @@ static void startDoResolve(void *p)
         else {
           dc->d_tcpConnection->state=TCPConnection::BYTE0;
           Utility::gettimeofday(&g_now, 0); // needs to be updated
-          t_fdm->addReadFD(dc->d_socket, handleRunningTCPQuestion, dc->d_tcpConnection);
-          t_fdm->setReadTTD(dc->d_socket, g_now, g_tcpTimeout);
+          struct timeval ttd = g_now;
+          ttd.tv_sec += g_tcpTimeout;
+
+          t_fdm->addReadFD(dc->d_socket, handleRunningTCPQuestion, dc->d_tcpConnection, &ttd);
         }
       }
     }
@@ -2047,11 +2052,11 @@ static void handleNewTCPQuestion(int fd, FDMultiplexer::funcparam_t& )
     std::shared_ptr<TCPConnection> tc = std::make_shared<TCPConnection>(newsock, addr);
     tc->state=TCPConnection::BYTE0;
 
-    t_fdm->addReadFD(tc->getFD(), handleRunningTCPQuestion, tc);
+    struct timeval ttd;
+    Utility::gettimeofday(&ttd, 0);
+    ttd.tv_sec += g_tcpTimeout;
 
-    struct timeval now;
-    Utility::gettimeofday(&now, 0);
-    t_fdm->setReadTTD(tc->getFD(), now, g_tcpTimeout);
+    t_fdm->addReadFD(tc->getFD(), handleRunningTCPQuestion, tc, &ttd);
   }
 }
 
@@ -2385,6 +2390,7 @@ static void handleNewUDPQuestion(int fd, FDMultiplexer::funcparam_t& var)
             distributeAsyncFunction(data, boost::bind(doProcessUDPQuestion, data, fromaddr, dest, tv, fd));
           }
           else {
+            ++s_threadInfos[t_id].numberOfDistributedQueries;
             doProcessUDPQuestion(data, fromaddr, dest, tv, fd);
           }
         }
@@ -2633,6 +2639,14 @@ static void doStats(void)
     g_log<<Logger::Notice<<"stats: " <<  broadcastAccFunction<uint64_t>(pleaseGetPacketCacheSize) <<
     " packet cache entries, "<<(int)(100.0*broadcastAccFunction<uint64_t>(pleaseGetPacketCacheHits)/SyncRes::s_queries) << "% packet cache hits"<<endl;
 
+    size_t idx = 0;
+    for (const auto& threadInfo : s_threadInfos) {
+      if(threadInfo.isWorker) {
+        g_log<<Logger::Notice<<"Thread "<<idx<<" has been distributed "<<threadInfo.numberOfDistributedQueries<<" queries"<<endl;
+        ++idx;
+      }
+    }
+
     time_t now = time(0);
     if(lastOutputTime && lastQueryCount && now != lastOutputTime) {
       g_log<<Logger::Notice<<"stats: "<< (SyncRes::s_queries - lastQueryCount) / (now - lastOutputTime) <<" qps (average over "<< (now - lastOutputTime) << " seconds)"<<endl;
@@ -2818,7 +2832,7 @@ void broadcastFunction(const pipefunc_t& func)
 
 static bool trySendingQueryToWorker(unsigned int target, ThreadMSG* tmsg)
 {
-  const auto& targetInfo = s_threadInfos[target];
+  auto& targetInfo = s_threadInfos[target];
   if(!targetInfo.isWorker) {
     g_log<<Logger::Error<<"distributeAsyncFunction() tried to assign a query to a non-worker thread"<<endl;
     exit(1);
@@ -2843,9 +2857,52 @@ static bool trySendingQueryToWorker(unsigned int target, ThreadMSG* tmsg)
     }
   }
 
+  ++targetInfo.numberOfDistributedQueries;
+
   return true;
 }
 
+static unsigned int getWorkerLoad(size_t workerIdx)
+{
+  const auto mt = s_threadInfos[/* skip handler */ 1 + g_numDistributorThreads + workerIdx].mt;
+  if (mt != nullptr) {
+    return mt->numProcesses();
+  }
+  return 0;
+}
+
+static unsigned int selectWorker(unsigned int hash)
+{
+  if (s_balancingFactor == 0) {
+    return /* skip handler */ 1 + g_numDistributorThreads + (hash % g_numWorkerThreads);
+  }
+
+  /* we start with one, representing the query we are currently handling */
+  double currentLoad = 1;
+  std::vector<unsigned int> load(g_numWorkerThreads);
+  for (size_t idx = 0; idx < g_numWorkerThreads; idx++) {
+    load[idx] = getWorkerLoad(idx);
+    currentLoad += load[idx];
+    // cerr<<"load for worker "<<idx<<" is "<<load[idx]<<endl;
+  }
+
+  double targetLoad = (currentLoad / g_numWorkerThreads) * s_balancingFactor;
+  // cerr<<"total load is "<<currentLoad<<", number of workers is "<<g_numWorkerThreads<<", target load is "<<targetLoad<<endl;
+
+  unsigned int worker = hash % g_numWorkerThreads;
+  /* at least one server has to be at or below the average load */
+  if (load[worker] > targetLoad) {
+    ++g_stats.rebalancedQueries;
+    do {
+      // cerr<<"worker "<<worker<<" is above the target load, selecting another one"<<endl;
+      worker = (worker + 1) % g_numWorkerThreads;
+    }
+    while(load[worker] > targetLoad);
+  }
+
+  return /* skip handler */ 1 + g_numDistributorThreads + worker;
+}
+
 // This function is only called by the distributor threads, when pdns-distributes-queries is set
 void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
 {
@@ -2855,7 +2912,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
   }
 
   unsigned int hash = hashQuestion(packet.c_str(), packet.length(), g_disthashseed);
-  unsigned int target = /* skip handler */ 1 + g_numDistributorThreads + (hash % g_numWorkerThreads);
+  unsigned int target = selectWorker(hash);
 
   ThreadMSG* tmsg = new ThreadMSG();
   tmsg->func = func;
@@ -2976,30 +3033,31 @@ template string broadcastAccFunction(const boost::function<string*()>& fun); //
 template uint64_t broadcastAccFunction(const boost::function<uint64_t*()>& fun); // explicit instantiation
 template vector<ComboAddress> broadcastAccFunction(const boost::function<vector<ComboAddress> *()>& fun); // explicit instantiation
 template vector<pair<DNSName,uint16_t> > broadcastAccFunction(const boost::function<vector<pair<DNSName, uint16_t> > *()>& fun); // explicit instantiation
+template ThreadTimes broadcastAccFunction(const boost::function<ThreadTimes*()>& fun);
 
 static void handleRCC(int fd, FDMultiplexer::funcparam_t& var)
 {
-  string remote;
-  string msg=s_rcc.recv(&remote);
-  RecursorControlParser rcp;
-  RecursorControlParser::func_t* command;
+  try {
+    string remote;
+    string msg=s_rcc.recv(&remote);
+    RecursorControlParser rcp;
+    RecursorControlParser::func_t* command;
 
-  string answer=rcp.getAnswer(msg, &command);
+    string answer=rcp.getAnswer(msg, &command);
 
-  // If we are inside a chroot, we need to strip
-  if (!arg()["chroot"].empty()) {
-    size_t len = arg()["chroot"].length();
-    remote = remote.substr(len);
-  }
+    // If we are inside a chroot, we need to strip
+    if (!arg()["chroot"].empty()) {
+      size_t len = arg()["chroot"].length();
+      remote = remote.substr(len);
+    }
 
-  try {
     s_rcc.send(answer, &remote);
     command();
   }
-  catch(std::exception& e) {
+  catch(const std::exception& e) {
     g_log<<Logger::Error<<"Error dealing with control socket request: "<<e.what()<<endl;
   }
-  catch(PDNSException& ae) {
+  catch(const PDNSException& ae) {
     g_log<<Logger::Error<<"Error dealing with control socket request: "<<ae.reason<<endl;
   }
 }
@@ -3630,6 +3688,7 @@ static int serviceMain(int argc, char*argv[])
   }
 
   SyncRes::s_minimumTTL = ::arg().asNum("minimum-ttl-override");
+  SyncRes::s_minimumECSTTL = ::arg().asNum("ecs-minimum-ttl-override");
 
   SyncRes::s_nopacketcache = ::arg().mustDo("disable-packetcache");
 
@@ -3653,6 +3712,10 @@ static int serviceMain(int argc, char*argv[])
 
   SyncRes::s_ecsipv4limit = ::arg().asNum("ecs-ipv4-bits");
   SyncRes::s_ecsipv6limit = ::arg().asNum("ecs-ipv6-bits");
+  SyncRes::clearECSStats();
+  SyncRes::s_ecsipv4cachelimit = ::arg().asNum("ecs-ipv4-cache-bits");
+  SyncRes::s_ecsipv6cachelimit = ::arg().asNum("ecs-ipv6-cache-bits");
+  SyncRes::s_ecscachelimitttl = ::arg().asNum("ecs-cache-limit-ttl");
 
   if (!::arg().isEmpty("ecs-scope-zero-address")) {
     ComboAddress scopeZero(::arg()["ecs-scope-zero-address"]);
@@ -3716,6 +3779,12 @@ static int serviceMain(int argc, char*argv[])
 
   g_statisticsInterval = ::arg().asNum("statistics-interval");
 
+  s_balancingFactor = ::arg().asDouble("distribution-load-factor");
+  if (s_balancingFactor != 0.0 && s_balancingFactor < 1.0) {
+    s_balancingFactor = 0.0;
+    g_log<<Logger::Warning<<"Asked to run with a distribution-load-factor below 1.0, disabling it instead"<<endl;
+  }
+
 #ifdef SO_REUSEPORT
   g_reusePort = ::arg().mustDo("reuseport");
 #endif
@@ -3857,6 +3926,11 @@ static int serviceMain(int argc, char*argv[])
   g_tcpMaxQueriesPerConn=::arg().asNum("max-tcp-queries-per-connection");
   s_maxUDPQueriesPerRound=::arg().asNum("max-udp-queries-per-round");
 
+  blacklistStats(StatComponent::API, ::arg()["stats-api-blacklist"]);
+  blacklistStats(StatComponent::Carbon, ::arg()["stats-carbon-blacklist"]);
+  blacklistStats(StatComponent::RecControl, ::arg()["stats-rec-control-blacklist"]);
+  blacklistStats(StatComponent::SNMP, ::arg()["stats-snmp-blacklist"]);
+
   if (::arg().mustDo("snmp-agent")) {
     g_snmpAgent = std::make_shared<RecursorSNMPAgent>("recursor", ::arg()["snmp-master-socket"]);
     g_snmpAgent->run();
@@ -4010,6 +4084,7 @@ try
   }
 
   MT=std::unique_ptr<MTasker<PacketID,string> >(new MTasker<PacketID,string>(::arg().asNum("stack-size")));
+  threadInfo.mt = MT.get();
 
 #ifdef HAVE_PROTOBUF
   /* start protobuf export threads if needed */
@@ -4159,6 +4234,7 @@ int main(int argc, char **argv)
   g_argc = argc;
   g_argv = argv;
   g_stats.startupTime=time(0);
+  Utility::srandom();
   versionSetProduct(ProductRecursor);
   reportBasicTypes();
   reportOtherTypes();
@@ -4197,6 +4273,7 @@ int main(int argc, char **argv)
     ::arg().set("webserver-port", "Port of webserver to listen on") = "8082";
     ::arg().set("webserver-password", "Password required for accessing the webserver") = "";
     ::arg().set("webserver-allow-from","Webserver access is only allowed from these subnets")="127.0.0.1,::1";
+    ::arg().set("webserver-loglevel", "Amount of logging in the webserver (none, normal, detailed)") = "normal";
     ::arg().set("carbon-ourname", "If set, overrides our reported hostname for carbon stats")="";
     ::arg().set("carbon-server", "If set, send metrics in carbon (graphite) format to this server IP address")="";
     ::arg().set("carbon-interval", "Number of seconds between carbon (graphite) updates")="30";
@@ -4254,7 +4331,11 @@ int main(int argc, char **argv)
     ::arg().set("latency-statistic-size","Number of latency values to calculate the qa-latency average")="10000";
     ::arg().setSwitch( "disable-packetcache", "Disable packetcache" )= "no";
     ::arg().set("ecs-ipv4-bits", "Number of bits of IPv4 address to pass for EDNS Client Subnet")="24";
+    ::arg().set("ecs-ipv4-cache-bits", "Maximum number of bits of IPv4 mask to cache ECS response")="24";
     ::arg().set("ecs-ipv6-bits", "Number of bits of IPv6 address to pass for EDNS Client Subnet")="56";
+    ::arg().set("ecs-ipv6-cache-bits", "Maximum number of bits of IPv6 mask to cache ECS response")="56";
+    ::arg().set("ecs-minimum-ttl-override", "Set under adverse conditions, a minimum TTL for records in ECS-specific answers")="0";
+    ::arg().set("ecs-cache-limit-ttl", "Minimum TTL to cache ECS response")="0";
     ::arg().set("edns-subnet-whitelist", "List of netmasks and domains that we should enable EDNS subnet for")="";
     ::arg().set("ecs-add-for", "List of client netmasks for which EDNS Client Subnet will be added")="0.0.0.0/0, ::/0, " LOCAL_NETS_INVERSE;
     ::arg().set("ecs-scope-zero-address", "Address to send to whitelisted authoritative servers for incoming queries with ECS prefix-length source of 0")="";
@@ -4280,6 +4361,18 @@ int main(int argc, char **argv)
     ::arg().setSwitch("snmp-agent", "If set, register as an SNMP agent")="no";
     ::arg().set("snmp-master-socket", "If set and snmp-agent is set, the socket to use to register to the SNMP master")="";
 
+    std::string defaultBlacklistedStats = "cache-bytes, packetcache-bytes, special-memory-usage";
+    for (size_t idx = 0; idx < 32; idx++) {
+      defaultBlacklistedStats += ", ecs-v4-response-bits-" + std::to_string(idx + 1);
+    }
+    for (size_t idx = 0; idx < 128; idx++) {
+      defaultBlacklistedStats += ", ecs-v6-response-bits-" + std::to_string(idx + 1);
+    }
+    ::arg().set("stats-api-blacklist", "List of statistics that are disabled when retrieving the complete list of statistics via the API")=defaultBlacklistedStats;
+    ::arg().set("stats-carbon-blacklist", "List of statistics that are prevented from being exported via Carbon")=defaultBlacklistedStats;
+    ::arg().set("stats-rec-control-blacklist", "List of statistics that are prevented from being exported via rec_control get-all")=defaultBlacklistedStats;
+    ::arg().set("stats-snmp-blacklist", "List of statistics that are prevented from being exported via SNMP")=defaultBlacklistedStats;
+
     ::arg().set("tcp-fast-open", "Enable TCP Fast Open support on the listening sockets, using the supplied numerical value as the queue size")="0";
     ::arg().set("nsec3-max-iterations", "Maximum number of iterations allowed for an NSEC3 record")="2500";
 
@@ -4295,6 +4388,7 @@ int main(int argc, char **argv)
     ::arg().set("udp-source-port-avoid", "List of comma separated UDP port number to avoid")="11211";
     ::arg().set("rng", "Specify random number generator to use. Valid values are auto,sodium,openssl,getrandom,arc4random,urandom.")="auto";
     ::arg().set("public-suffix-list-file", "Path to the Public Suffix List file, if any")="";
+    ::arg().set("distribution-load-factor", "The load factor used when PowerDNS is distributing queries to worker threads")="0.0";
 #ifdef NOD_ENABLED
     ::arg().set("new-domain-tracking", "Track newly observed domains (i.e. never seen before).")="no";
     ::arg().set("new-domain-log", "Log newly observed domains.")="yes";
index d7e40a4587ba51b3d9aa353671c81d8b6fabc69a..9556faf19f5626028f9c8b7a73e865807509c518 100644 (file)
@@ -20,6 +20,7 @@
 #include "zoneparser-tng.hh"
 #include "signingpipe.hh"
 #include "dns_random.hh"
+#include "ipcipher.hh"
 #include <fstream>
 #include <termios.h>            //termios, TCSANOW, ECHO, ICANON
 #include "opensslsigners.hh"
@@ -76,7 +77,7 @@ void loadMainConfig(const std::string& configdir)
     exit(0);
   }
 
-  if(::arg()["config-name"]!="") 
+  if(::arg()["config-name"]!="")
     s_programname+="-"+::arg()["config-name"];
 
   string configname=::arg()["config-dir"]+"/"+s_programname+".conf";
@@ -112,22 +113,22 @@ void loadMainConfig(const std::string& configdir)
   g_log.toConsole(Logger::Error);   // so we print any errors
   BackendMakers().launch(::arg()["launch"]); // vrooooom!
   if(::arg().asNum("loglevel") >= 3) // so you can't kill our errors
-    g_log.toConsole((Logger::Urgency)::arg().asNum("loglevel"));  
+    g_log.toConsole((Logger::Urgency)::arg().asNum("loglevel"));
 
   //cerr<<"Backend: "<<::arg()["launch"]<<", '" << ::arg()["gmysql-dbname"] <<"'" <<endl;
 
   S.declare("qsize-q","Number of questions waiting for database attention");
-          
+
   ::arg().set("max-cache-entries", "Maximum number of cache entries")="1000000";
-  ::arg().set("cache-ttl","Seconds to store packets in the PacketCache")="20";              
+  ::arg().set("cache-ttl","Seconds to store packets in the PacketCache")="20";
   ::arg().set("negquery-cache-ttl","Seconds to store negative query results in the QueryCache")="60";
-  ::arg().set("query-cache-ttl","Seconds to store query results in the QueryCache")="20";              
+  ::arg().set("query-cache-ttl","Seconds to store query results in the QueryCache")="20";
   ::arg().set("default-soa-name","name to insert in the SOA record if none set in the backend")="a.misconfigured.powerdns.server";
   ::arg().set("default-soa-mail","mail address to insert in the SOA record if none set in the backend")="";
   ::arg().set("soa-refresh-default","Default SOA refresh")="10800";
   ::arg().set("soa-retry-default","Default SOA retry")="3600";
   ::arg().set("soa-expire-default","Default SOA expire")="604800";
-  ::arg().set("soa-minimum-ttl","Default SOA minimum ttl")="3600";    
+  ::arg().set("soa-minimum-ttl","Default SOA minimum ttl")="3600";
   ::arg().set("chroot","Switch to this chroot jail")="";
   ::arg().set("dnssec-key-cache-ttl","Seconds to cache DNSSEC keys from the database")="30";
   ::arg().set("domain-metadata-cache-ttl","Seconds to cache domain metadata from the database")="60";
@@ -200,7 +201,7 @@ void dbBench(const std::string& fname)
   dt.set();
   unsigned int hits=0, misses=0;
   for(; n < 10000; ++n) {
-    DNSName domain(domains[random() % domains.size()]);
+    DNSName domain(domains[dns_random(domains.size())]);
     B.lookup(QType(QType::NS), domain);
     while(B.get(rr)) {
       hits++;
@@ -322,7 +323,7 @@ int checkZone(DNSSECKeeper &dk, UeberBackend &B, const DNSName& zone, const vect
       records.push_back(drr);
     }
   }
-  else 
+  else
     records=*suppliedrecords;
 
   for(auto &rr : records) { // we modify this
@@ -665,7 +666,7 @@ int increaseSerial(const DNSName& zone, DNSSECKeeper &dk)
     NSEC3PARAMRecordContent ns3pr;
     bool narrow;
     bool haveNSEC3=dk.getNSEC3PARAM(zone, &ns3pr, &narrow);
-  
+
     DNSName ordername;
     if(haveNSEC3) {
       if(!narrow)
@@ -716,7 +717,7 @@ void listKey(DomainInfo const &di, DNSSECKeeper& dk, bool printHeader = true) {
     spacelen = (std::to_string(key.first.getKey()->getBits()).length() >= 8) ? 1 : 8 - std::to_string(key.first.getKey()->getBits()).length();
     if (key.first.getKey()->getBits() < 1) {
       cout<<"invalid "<<endl;
-      continue; 
+      continue;
     } else {
       cout<<key.first.getKey()->getBits()<<string(spacelen, ' ');
     }
@@ -778,7 +779,7 @@ int listKeys(const string &zname, DNSSECKeeper& dk){
 int listZone(const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -790,9 +791,9 @@ int listZone(const DNSName &zone) {
   
   while(di.backend->get(rr)) {
     if(rr.qtype.getCode()) {
-      if ( (rr.qtype.getCode() == QType::NS || rr.qtype.getCode() == QType::SRV || rr.qtype.getCode() == QType::MX || rr.qtype.getCode() == QType::CNAME) && !rr.content.empty() && rr.content[rr.content.size()-1] != '.') 
+      if ( (rr.qtype.getCode() == QType::NS || rr.qtype.getCode() == QType::SRV || rr.qtype.getCode() == QType::MX || rr.qtype.getCode() == QType::CNAME) && !rr.content.empty() && rr.content[rr.content.size()-1] != '.')
        rr.content.append(1, '.');
-       
+
       cout<<rr.qname<<"\t"<<rr.ttl<<"\tIN\t"<<rr.qtype.getName()<<"\t"<<rr.content<<"\n";
     }
   }
@@ -801,8 +802,8 @@ int listZone(const DNSName &zone) {
 }
 
 // lovingly copied from http://stackoverflow.com/questions/1798511/how-to-avoid-press-enter-with-any-getchar
-int read1char(){   
-    int c;   
+int read1char(){
+    int c;
     static struct termios oldt, newt;
 
     /*tcgetattr gets the parameters of the current terminal
@@ -814,7 +815,7 @@ int read1char(){
 
     /*ICANON normally takes care that one line at a time will be processed
     that means it will return if it sees a "\n" or an EOF or an EOL*/
-    newt.c_lflag &= ~(ICANON);          
+    newt.c_lflag &= ~(ICANON);
 
     /*Those new settings will be set to STDIN
     TCSANOW tells tcsetattr to change attributes immediately. */
@@ -831,7 +832,7 @@ int read1char(){
 int clearZone(DNSSECKeeper& dk, const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -847,7 +848,7 @@ int clearZone(DNSSECKeeper& dk, const DNSName &zone) {
 int editZone(DNSSECKeeper& dk, const DNSName &zone) {
   UeberBackend B;
   DomainInfo di;
-  
+
   if (! B.getDomainInfo(zone, di)) {
     cerr<<"Domain '"<<zone<<"' not found!"<<endl;
     return EXIT_FAILURE;
@@ -937,7 +938,7 @@ int editZone(DNSSECKeeper& dk, const DNSName &zone) {
     cerr<<"\n";
     if(c!='a')
       post.clear();
-    if(c=='e') 
+    if(c=='e')
       goto editMore;
     else if(c=='r')
       goto editAgain;
@@ -1002,6 +1003,20 @@ int editZone(DNSSECKeeper& dk, const DNSName &zone) {
   return EXIT_SUCCESS;
 }
 
+static int xcryptIP(const std::string& cmd, const std::string& ip, const std::string& rkey)
+{
+
+  ComboAddress ca(ip), ret;
+
+  if(cmd=="ipencrypt")
+    ret = encryptCA(ca, rkey);
+  else
+    ret = decryptCA(ca, rkey);
+
+  cout<<ret.toString()<<endl;
+  return EXIT_SUCCESS;
+}
+
 
 int loadZone(DNSName zone, const string& fname) {
   UeberBackend B;
@@ -1013,7 +1028,7 @@ int loadZone(DNSName zone, const string& fname) {
   else {
     cerr<<"Creating '"<<zone<<"'"<<endl;
     B.createDomain(zone);
-    
+
     if(!B.getDomainInfo(zone, di)) {
       cerr<<"Domain '"<<zone<<"' was not created - perhaps backend ("<<::arg()["launch"]<<") does not support storing new zones."<<endl;
       return EXIT_FAILURE;
@@ -1021,13 +1036,13 @@ int loadZone(DNSName zone, const string& fname) {
   }
   DNSBackend* db = di.backend;
   ZoneParserTNG zpt(fname, zone);
-  
+
   DNSResourceRecord rr;
   if(!db->startTransaction(zone, di.id)) {
     cerr<<"Unable to start transaction for load of zone '"<<zone<<"'"<<endl;
     return EXIT_FAILURE;
   }
-  rr.domain_id=di.id;  
+  rr.domain_id=di.id;
   bool haveSOA = false;
   while(zpt.get(rr)) {
     if(!rr.qname.isPartOf(zone) && rr.qname!=zone) {
@@ -1065,7 +1080,7 @@ int createZone(const DNSName &zone, const DNSName& nsname) {
   rr.auth = 1;
   rr.ttl = ::arg().asNum("default-ttl");
   rr.qtype = "SOA";
-  
+
   string soa = (boost::format("%s %s 1")
                 % (nsname.empty() ? ::arg()["default-soa-name"] : nsname.toString())
                 % (::arg().isEmpty("default-soa-mail") ? (DNSName("hostmaster.") + zone).toString() : ::arg()["default-soa-mail"])
@@ -1082,7 +1097,7 @@ int createZone(const DNSName &zone, const DNSName& nsname) {
     rr.content=nsname.toStringNoDot();
     di.backend->feedRecord(rr, DNSName());
   }
-  
+
   di.backend->commitTransaction();
 
   return EXIT_SUCCESS;
@@ -1144,7 +1159,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
   DNSName name;
   if(cmds[2]=="@")
     name=zone;
-  else 
+  else
     name=DNSName(cmds[2])+zone;
 
   rr.qtype = DNSRecordContent::TypeToNumber(cmds[3]);
@@ -1212,7 +1227,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
     newrrs.push_back(rr);
   }
 
-  
+
   di.backend->replaceRRSet(di.id, name, rr.qtype, newrrs);
   // need to be explicit to bypass the ueberbackend cache!
   di.backend->lookup(rr.qtype, name, 0, di.id);
@@ -1224,7 +1239,7 @@ int addOrReplaceRecord(bool addOrReplace, const vector<string>& cmds) {
 }
 
 // delete-rrset zone name type
-int deleteRRSet(const std::string& zone_, const std::string& name_, const std::string& type_) 
+int deleteRRSet(const std::string& zone_, const std::string& name_, const std::string& type_)
 {
   UeberBackend B;
   DomainInfo di;
@@ -1237,7 +1252,7 @@ int deleteRRSet(const std::string& zone_, const std::string& name_, const std::s
   DNSName name;
   if(name_=="@")
     name=zone;
-  else 
+  else
     name=DNSName(name_)+zone;
 
   QType qt(QType::chartocode(type_.c_str()));
@@ -1302,16 +1317,16 @@ void testSpeed(DNSSECKeeper& dk, const DNSName& zone, const string& remote, int
   rr.ttl=3600;
   rr.auth=1;
   rr.qclass = QClass::IN;
-  
+
   UeberBackend db("key-only");
-  
+
   if ( ! db.backends.size() )
   {
     throw runtime_error("No backends available for DNSSEC key storage");
   }
 
   ChunkedSigningPipe csp(DNSName(zone), 1, cores);
-  
+
   vector<DNSZoneRecord> signatures;
   uint32_t rnd;
   unsigned char* octets = (unsigned char*)&rnd;
@@ -1319,11 +1334,11 @@ void testSpeed(DNSSECKeeper& dk, const DNSName& zone, const string& remote, int
   DTime dt;
   dt.set();
   for(unsigned int n=0; n < 100000; ++n) {
-    rnd = random();
-    snprintf(tmp, sizeof(tmp), "%d.%d.%d.%d", 
+    rnd = dns_random(UINT32_MAX);
+    snprintf(tmp, sizeof(tmp), "%d.%d.%d.%d",
       octets[0], octets[1], octets[2], octets[3]);
     rr.content=tmp;
-    
+
     snprintf(tmp, sizeof(tmp), "r-%u", rnd);
     rr.qname=DNSName(tmp)+zone;
     DNSZoneRecord dzr;
@@ -1369,7 +1384,7 @@ void verifyCrypto(const string& zone)
       toSign.push_back(DNSRecordContent::mastermake(rr.qtype.getCode(), 1, rr.content));
     }
   }
-  
+
   string msg = getMessageForRRSET(qname, rrc, toSign);
   cerr<<"Verify: "<<DNSCryptoKeyEngine::makeFromPublicKeyString(drc.d_algorithm, drc.d_key)->verify(msg, rrc.d_signature)<<endl;
   if(dsrc.d_digesttype) {
@@ -1871,7 +1886,7 @@ int addOrSetMeta(const DNSName& zone, const string& kind, const vector<string>&
 
 int main(int argc, char** argv)
 try
-{  
+{
   po::options_description desc("Allowed options");
   desc.add_options()
     ("help,h", "produce help message")
@@ -1949,13 +1964,15 @@ try
 #ifdef HAVE_P11KIT1
     cout<<"hsm assign ZONE ALGORITHM {ksk|zsk} MODULE SLOT PIN LABEL"<<endl<<
           "                                   Assign a hardware signing module to a ZONE"<<endl;
-    cout<<"hsm create-key ZONE KEY-ID [BITS]  Create a key using hardware signing module for ZONE (use assign first)"<<endl; 
+    cout<<"hsm create-key ZONE KEY-ID [BITS]  Create a key using hardware signing module for ZONE (use assign first)"<<endl;
     cout<<"                                   BITS defaults to 2048"<<endl;
 #endif
     cout<<"increase-serial ZONE               Increases the SOA-serial by 1. Uses SOA-EDIT"<<endl;
     cout<<"import-tsig-key NAME ALGORITHM KEY Import TSIG key"<<endl;
     cout<<"import-zone-key ZONE FILE          Import from a file a private key, ZSK or KSK"<<endl;
-    cout<<"       [active|inactive] [ksk|zsk]  Defaults to KSK and active"<<endl;
+    cout<<"       [active|inactive] [ksk|zsk] Defaults to KSK and active"<<endl;
+    cout<<"ipdecrypt IP passphrase/key [key]  Encrypt IP address using passphrase or base64 key"<<endl;
+    cout<<"ipencrypt IP passphrase/key [key]  Encrypt IP address using passphrase or base64 key"<<endl;
     cout<<"load-zone ZONE FILE                Load ZONE from FILE, possibly creating zone or atomically"<<endl;
     cout<<"                                   replacing contents"<<endl;
     cout<<"list-algorithms [with-backend]     List all DNSSEC algorithms supported, optionally also listing the crypto library used"<<endl;
@@ -1977,7 +1994,7 @@ try
     cout<<"set-presigned ZONE                 Use presigned RRSIGs from storage"<<endl;
     cout<<"set-publish-cdnskey ZONE           Enable sending CDNSKEY responses for ZONE"<<endl;
     cout<<"set-publish-cds ZONE [DIGESTALGOS] Enable sending CDS responses for ZONE, using DIGESTALGOS as signature algorithms"<<endl;
-    cout<<"                                   DIGESTALGOS should be a comma separated list of numbers, is is '1,2' by default"<<endl;
+    cout<<"                                   DIGESTALGOS should be a comma separated list of numbers, it is '1,2' by default"<<endl;
     cout<<"add-meta ZONE KIND VALUE           Add zone metadata, this adds to the existing KIND"<<endl;
     cout<<"                   [VALUE ...]"<<endl;
     cout<<"set-meta ZONE KIND [VALUE] [VALUE] Set zone metadata, optionally providing a value. *No* value clears meta"<<endl;
@@ -2004,6 +2021,25 @@ try
     return 1;
   }
 
+  if(cmds[0] == "ipencrypt" || cmds[0]=="ipdecrypt") {
+    if(cmds.size() < 3 || (cmds.size()== 4 && cmds[3]!="key")) {
+      cerr<<"Syntax: pdnsutil [ipencrypt|ipdecrypt] IP passphrase [key]"<<endl;
+      return 0;
+    }
+    string key;
+    if(cmds.size()==4) {
+      if(B64Decode(cmds[2], key) < 0) {
+        cerr<<"Could not parse '"<<cmds[3]<<"' as base64"<<endl;
+        return 0;
+      }
+    }
+    else {
+      key = makeIPCipherKey(cmds[2]);
+    }
+    exit(xcryptIP(cmds[0], cmds[1], key));
+  }
+
+
   if(cmds[0] == "test-algorithms") {
     if (testAlgorithms())
       return 0;
@@ -2071,8 +2107,8 @@ try
       return 0;
     }
     unsigned int exitCode = 0;
-    for(unsigned int n = 1; n < cmds.size(); ++n) 
-      if (!rectifyZone(dk, DNSName(cmds[n]))) 
+    for(unsigned int n = 1; n < cmds.size(); ++n)
+      if (!rectifyZone(dk, DNSName(cmds[n])))
        exitCode = 1;
     return exitCode;
   }
@@ -2240,7 +2276,7 @@ try
         active=false;
       } else if(pdns_stou(cmds[n])) {
         bits = pdns_stou(cmds[n]);
-      } else { 
+      } else {
         cerr<<"Unknown algorithm, key flag or size '"<<cmds[n]<<"'"<<endl;
         exit(EXIT_FAILURE);;
       }
@@ -2391,7 +2427,7 @@ try
       }
       dk.commitTransaction();
     }
-    
+
     for(const auto& zone : mustRectify)
       rectifyZone(dk, zone);
 
@@ -2489,7 +2525,7 @@ try
   else if(cmds[0]=="set-presigned") {
     if(cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil set-presigned ZONE"<<endl;
-      return 0; 
+      return 0;
     }
     if (! dk.setPresigned(DNSName(cmds[1]))) {
       cerr << "Could not set presigned for " << cmds[1] << " (is DNSSEC enabled in your backend?)" << endl;
@@ -2527,7 +2563,7 @@ try
   else if(cmds[0]=="unset-presigned") {
     if(cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil unset-presigned ZONE"<<endl;
-      return 0;  
+      return 0;
     }
     if (! dk.unsetPresigned(DNSName(cmds[1]))) {
       cerr << "Could not unset presigned on for " << cmds[1] << endl;
@@ -2573,7 +2609,7 @@ try
     if(narrow) {
       cerr<<"The '"<<zone<<"' zone uses narrow NSEC3, but calculating hash anyhow"<<endl;
     }
-      
+
     cout<<toBase32Hex(hashQNameWithSalt(ns3pr, record))<<endl;
   }
   else if(cmds[0]=="unset-nsec3") {
@@ -2597,7 +2633,7 @@ try
     unsigned int id=pdns_stou(cmds[2]);
     DNSSECPrivateKey dpk=dk.getKeyById(DNSName(zone), id);
     cout << dpk.getKey()->convertToISC() <<endl;
-  }  
+  }
   else if(cmds[0]=="increase-serial") {
     if (cmds.size() < 2) {
       cerr<<"Syntax: pdnsutil increase-serial ZONE"<<endl;
@@ -2626,14 +2662,14 @@ try
     DNSKEYRecordContent drc;
     shared_ptr<DNSCryptoKeyEngine> key(DNSCryptoKeyEngine::makeFromPEMString(drc, raw));
     dpk.setKey(key);
-    
+
     dpk.d_algorithm = pdns_stou(cmds[3]);
-    
+
     if(dpk.d_algorithm == DNSSECKeeper::RSASHA1NSEC3SHA1)
       dpk.d_algorithm = DNSSECKeeper::RSASHA1;
-      
+
     cerr<<(int)dpk.d_algorithm<<endl;
-    
+
     if(cmds.size() > 4) {
       if(pdns_iequals(cmds[4], "ZSK"))
         dpk.d_flags = 256;
@@ -2659,7 +2695,7 @@ try
     } else {
       cout<<std::to_string(id)<<endl;
     }
-    
+
   }
   else if(cmds[0]=="import-zone-key") {
     if(cmds.size() < 3) {
@@ -2673,11 +2709,11 @@ try
     shared_ptr<DNSCryptoKeyEngine> key(DNSCryptoKeyEngine::makeFromISCFile(drc, fname.c_str()));
     dpk.setKey(key);
     dpk.d_algorithm = drc.d_algorithm;
-    
+
     if(dpk.d_algorithm == DNSSECKeeper::RSASHA1NSEC3SHA1)
       dpk.d_algorithm = DNSSECKeeper::RSASHA1;
-    
-    dpk.d_flags = 257; 
+
+    dpk.d_flags = 257;
     bool active=true;
 
     for(unsigned int n = 3; n < cmds.size(); ++n) {
@@ -2689,10 +2725,10 @@ try
         active = 1;
       else if(pdns_iequals(cmds[n], "passive") || pdns_iequals(cmds[n], "inactive")) // passive eventually needs to be removed
         active = 0;
-      else { 
+      else {
         cerr<<"Unknown key flag '"<<cmds[n]<<"'"<<endl;
         exit(1);
-      }          
+      }
     }
     int64_t id;
     if (!dk.addKey(DNSName(zone), dpk, id, active)) {
@@ -2770,14 +2806,14 @@ try
         }
       }
     }
-    dpk->create(bits); 
-    dspk.setKey(dpk); 
-    dspk.d_algorithm = algorithm; 
-    dspk.d_flags = keyOrZone ? 257 : 256; 
+    dpk->create(bits);
+    dspk.setKey(dpk);
+    dspk.d_algorithm = algorithm;
+    dspk.d_flags = keyOrZone ? 257 : 256;
 
-    // print key to stdout 
-    cout << "Flags: " << dspk.d_flags << endl << 
-             dspk.getKey()->convertToISC() << endl; 
+    // print key to stdout
+    cout << "Flags: " << dspk.d_flags << endl <<
+             dspk.getKey()->convertToISC() << endl;
   } else if (cmds[0]=="generate-tsig-key") {
     string usage = "Syntax: " + cmds[0] + " name (hmac-md5|hmac-sha1|hmac-sha224|hmac-sha256|hmac-sha384|hmac-sha512)";
     if (cmds.size() < 3) {
@@ -2860,7 +2896,7 @@ try
         return 1;
      }
      UeberBackend B("default");
-     std::vector<std::string> meta; 
+     std::vector<std::string> meta;
      if (!B.getDomainMetadata(zname, metaKey, meta)) {
        cerr << "Failure enabling TSIG key " << name << " for " << zname << endl;
        return 1;
@@ -2942,7 +2978,7 @@ try
       for(const auto& each_meta: meta) {
         cout << each_meta.first << " = " << boost::join(each_meta.second, ", ") << endl;
       }
-    }  
+    }
     return 0;
 
   } else if (cmds[0]=="set-meta" || cmds[0]=="add-meta") {
@@ -3008,11 +3044,11 @@ try
          pub_label = label;
 
       std::ostringstream iscString;
-      iscString << "Private-key-format: v1.2" << std::endl << 
-        "Algorithm: " << algorithm << std::endl << 
+      iscString << "Private-key-format: v1.2" << std::endl <<
+        "Algorithm: " << algorithm << std::endl <<
         "Engine: " << module << std::endl <<
         "Slot: " << slot << std::endl <<
-        "PIN: " << pin << std::endl << 
+        "PIN: " << pin << std::endl <<
         "Label: " << label << std::endl <<
         "PubLabel: " << pub_label << std::endl;
 
@@ -3067,19 +3103,19 @@ try
         cerr << "Unable to create key for unknown zone '" << zone << "'" << std::endl;
         return 1;
       }
+
       id = pdns_stou(cmds[3]);
-      std::vector<DNSBackend::KeyData> keys; 
+      std::vector<DNSBackend::KeyData> keys;
       if (!B.getDomainKeys(zone, keys)) {
         cerr << "No keys found for zone " << zone << std::endl;
         return 1;
-      } 
+      }
 
       std::shared_ptr<DNSCryptoKeyEngine> dke = nullptr;
-      // lookup correct key      
+      // lookup correct key
       for(DNSBackend::KeyData &kd :  keys) {
         if (kd.id == id) {
-          // found our key. 
+          // found our key.
           DNSKEYRecordContent dkrc;
           dke = DNSCryptoKeyEngine::makeFromISCString(dkrc, kd.content);
         }
@@ -3108,7 +3144,7 @@ try
     }
 #else
     cerr<<"PKCS#11 support not enabled"<<endl;
-    return 1; 
+    return 1;
 #endif
   } else if (cmds[0] == "b2b-migrate") {
     if (cmds.size() < 3) {
index 312c4bfece6919adb7db0e84cf0f50e8014d2064..399ad132ae8a3bd6dca4d566d2ed03fa952584e4 100644 (file)
@@ -37,7 +37,7 @@ public:
   virtual int run(struct timeval* tv, int timeout=500) override;
   virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter, const struct timeval* ttd=nullptr) override;
   virtual void removeFD(callbackmap_t& cbmap, int fd) override;
 
   string getName() const override
@@ -60,20 +60,14 @@ static struct RegisterOurselves
   }
 } doIt;
 
-void PollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void PollFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  Callback cb;
-  cb.d_callback=toDo;
-  cb.d_parameter=parameter;
-  memset(&cb.d_ttd, 0, sizeof(cb.d_ttd));
-  if(cbmap.count(fd))
-    throw FDMultiplexerException("Tried to add fd "+std::to_string(fd)+ " to multiplexer twice");
-  cbmap[fd]=cb;
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 }
 
 void PollFDMultiplexer::removeFD(callbackmap_t& cbmap, int fd)
 {
-  if(d_inrun && d_iter->first==fd)  // trying to remove us!
+  if(d_inrun && d_iter->d_fd==fd)  // trying to remove us!
     ++d_iter;
 
   if(!cbmap.erase(fd))
@@ -87,13 +81,13 @@ vector<struct pollfd> PollFDMultiplexer::preparePollFD() const
 
   struct pollfd pollfd;
   for(const auto& cb : d_readCallbacks) {
-    pollfd.fd = cb.first;
+    pollfd.fd = cb.d_fd;
     pollfd.events = POLLIN;
     pollfds.push_back(pollfd);
   }
 
   for(const auto& cb : d_writeCallbacks) {
-    pollfd.fd = cb.first;
+    pollfd.fd = cb.d_fd;
     pollfd.events = POLLOUT;
     pollfds.push_back(pollfd);
   }
@@ -110,7 +104,7 @@ void PollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
     throw FDMultiplexerException("poll returned error: " + stringerror());
 
   for(const auto& pollfd : pollfds) {
-    if (pollfd.revents == POLLIN || pollfd.revents == POLLOUT) {
+    if (pollfd.revents & POLLIN || pollfd.revents & POLLOUT) {
       fds.push_back(pollfd.fd);
     }
   }
@@ -134,19 +128,19 @@ int PollFDMultiplexer::run(struct timeval* now, int timeout)
   d_inrun=true;
 
   for(const auto& pollfd : pollfds) {
-    if(pollfd.revents == POLLIN) {
+    if(pollfd.revents & POLLIN) {
       d_iter=d_readCallbacks.find(pollfd.fd);
     
       if(d_iter != d_readCallbacks.end()) {
-        d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
         continue; // so we don't refind ourselves as writable!
       }
     }
-    else if(pollfd.revents == POLLOUT) {
+    else if(pollfd.revents & POLLOUT) {
       d_iter=d_writeCallbacks.find(pollfd.fd);
     
       if(d_iter != d_writeCallbacks.end()) {
-        d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+        d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       }
     }
   }
index c6e91e60fc25d71e7a55e11890c36e8a59706120..39939b2f4fbc437f491ed0a7561613a18f00e657 100644 (file)
@@ -25,7 +25,7 @@ public:
 
   virtual int run(struct timeval* tv, int timeout=500);
 
-  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter);
+  virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd=nullptr);
   virtual void removeFD(callbackmap_t& cbmap, int fd);
   string getName()
   {
@@ -59,9 +59,9 @@ PortsFDMultiplexer::PortsFDMultiplexer() : d_pevents(new port_event_t[s_maxevent
     throw FDMultiplexerException("Setting up port: "+stringerror());
 }
 
-void PortsFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter)
+void PortsFDMultiplexer::addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const boost::any& parameter, const struct timeval* ttd)
 {
-  accountingAddFD(cbmap, fd, toDo, parameter);
+  accountingAddFD(cbmap, fd, toDo, parameter, ttd);
 
   if(port_associate(d_portfd, PORT_SOURCE_FD, fd, (&cbmap == &d_readCallbacks) ? POLLIN : POLLOUT, 0) < 0) {
     cbmap.erase(fd);
@@ -113,7 +113,7 @@ int PortsFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_readCallbacks.find(d_pevents[n].portev_object);
     
     if(d_iter != d_readCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       if(d_readCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
                         POLLIN, 0) < 0)
         throw FDMultiplexerException("Unable to add fd back to ports (read): "+stringerror());
@@ -123,7 +123,7 @@ int PortsFDMultiplexer::run(struct timeval* now, int timeout)
     d_iter=d_writeCallbacks.find(d_pevents[n].portev_object);
     
     if(d_iter != d_writeCallbacks.end()) {
-      d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
+      d_iter->d_callback(d_iter->d_fd, d_iter->d_parameter);
       if(d_writeCallbacks.count(d_pevents[n].portev_object) && port_associate(d_portfd, PORT_SOURCE_FD, d_pevents[n].portev_object, 
                         POLLOUT, 0) < 0)
         throw FDMultiplexerException("Unable to add fd back to ports (write): "+stringerror());
index b8f4e258c6355c6130e1412a0397d036aa03f2df..5f2b92fa08e42db31a8d5f65c2488bb9bf65556c 100644 (file)
@@ -57,7 +57,7 @@ try
  
     if(msg.empty()) {
       typedef map<string,string> all_t;
-      all_t all=getAllStatsMap();
+      all_t all=getAllStatsMap(StatComponent::Carbon);
       
       ostringstream str;
       time_t now=time(0);
index e73d7e7d23114fab22f191e8080c7e034ef38f3c..45b48cea8deb9dc598866ff423ac1de39cf74550 100644 (file)
@@ -114,6 +114,7 @@ static const oid dnssecAuthenticDataQueriesOID[] = { RECURSOR_STATS_OID, 95 };
 static const oid dnssecCheckDisabledQueriesOID[] = { RECURSOR_STATS_OID, 96 };
 static const oid variableResponsesOID[] = { RECURSOR_STATS_OID, 97 };
 static const oid specialMemoryUsageOID[] = { RECURSOR_STATS_OID, 98 };
+static const oid rebalancedQueriesOID[] = { RECURSOR_STATS_OID, 99 };
 
 static std::unordered_map<oid, std::string> s_statsMap;
 
@@ -148,21 +149,37 @@ static int handleCounter64Stats(netsnmp_mib_handler* handler,
   }
 }
 
-static void registerCounter64Stat(const char* name, const oid statOID[], size_t statOIDLength)
+static int handleDisabledCounter64Stats(netsnmp_mib_handler* handler,
+                                        netsnmp_handler_registration* reginfo,
+                                        netsnmp_agent_request_info* reqinfo,
+                                        netsnmp_request_info* requests)
+{
+  if (reqinfo->mode != MODE_GET) {
+    return SNMP_ERR_GENERR;
+  }
+
+  if (reginfo->rootoid_len != OID_LENGTH(questionsOID) + 1) {
+    return SNMP_ERR_GENERR;
+  }
+
+  return RecursorSNMPAgent::setCounter64Value(requests, 0);
+}
+
+static void registerCounter64Stat(const std::string& name, const oid statOID[], size_t statOIDLength)
 {
   if (statOIDLength != OID_LENGTH(questionsOID)) {
-    g_log<<Logger::Error<<"Invalid OID for SNMP Counter64 statistic "<<std::string(name)<<endl;
+    g_log<<Logger::Error<<"Invalid OID for SNMP Counter64 statistic "<<name<<endl;
     return;
   }
 
   if (s_statsMap.find(statOID[statOIDLength - 1]) != s_statsMap.end()) {
-    g_log<<Logger::Error<<"OID for SNMP Counter64 statistic "<<std::string(name)<<" has already been registered"<<endl;
+    g_log<<Logger::Error<<"OID for SNMP Counter64 statistic "<<name<<" has already been registered"<<endl;
     return;
   }
 
-  s_statsMap[statOID[statOIDLength - 1]] = name;
-  netsnmp_register_scalar(netsnmp_create_handler_registration(name,
-                                                              handleCounter64Stats,
+  s_statsMap[statOID[statOIDLength - 1]] = name.c_str();
+  netsnmp_register_scalar(netsnmp_create_handler_registration(name.c_str(),
+                                                              isStatBlacklisted(StatComponent::SNMP, name) ? handleCounter64Stats : handleDisabledCounter64Stats,
                                                               statOID,
                                                               statOIDLength,
                                                               HANDLER_CAN_RONLY));
@@ -304,6 +321,6 @@ RecursorSNMPAgent::RecursorSNMPAgent(const std::string& name, const std::string&
   registerCounter64Stat("policy-result-truncate", policyResultTruncateOID, OID_LENGTH(policyResultTruncateOID));
   registerCounter64Stat("policy-result-custom", policyResultCustomOID, OID_LENGTH(policyResultCustomOID));
   registerCounter64Stat("special-memory-usage", specialMemoryUsageOID, OID_LENGTH(specialMemoryUsageOID));
-
+  registerCounter64Stat("rebalanced-queries", rebalancedQueriesOID, OID_LENGTH(rebalancedQueriesOID));
 #endif /* HAVE_NET_SNMP */
 }
index 0f1dd7973e7f4413000ff1e20b1f1aaf11fc1e2f..2eb6fb8fd6cd5bf0cbf84dd6aa2ea8521ab236d7 100644 (file)
@@ -19,8 +19,9 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
-#ifndef PDNS_REC_CHANNEL
-#define PDNS_REC_CHANNEL
+
+#pragma once
+
 #include <string>
 #include <map>
 #include <vector>
@@ -63,7 +64,9 @@ public:
   std::string getAnswer(const std::string& question, func_t** func);
 };
 
-std::map<std::string, std::string> getAllStatsMap();
+enum class StatComponent { API, Carbon, RecControl, SNMP };
+
+std::map<std::string, std::string> getAllStatsMap(StatComponent component);
 extern pthread_mutex_t g_carbon_config_lock;
 std::vector<std::pair<DNSName, uint16_t> >* pleaseGetQueryRing();
 std::vector<std::pair<DNSName, uint16_t> >* pleaseGetServfailQueryRing();
@@ -76,5 +79,9 @@ std::vector<ComboAddress>* pleaseGetTimeouts();
 DNSName getRegisteredName(const DNSName& dom);
 std::atomic<unsigned long>* getDynMetric(const std::string& str);
 optional<uint64_t> getStatByName(const std::string& name);
+bool isStatBlacklisted(StatComponent component, const std::string& name);
+void blacklistStat(StatComponent component, const string& name);
+void blacklistStats(StatComponent component, const string& stats);
+
 void registerAllStats();
-#endif 
+
index 277d341063c5f4055f39e517dd112597ca4488dc..77bbd872e83c431d7425ef4a1157f8f4cd27d660 100644 (file)
 pthread_mutex_t g_carbon_config_lock=PTHREAD_MUTEX_INITIALIZER;
 
 static map<string, const uint32_t*> d_get32bitpointers;
-static map<string, const uint64_t*> d_get64bitpointers;
 static map<string, const std::atomic<uint64_t>*> d_getatomics;
 static map<string, function< uint64_t() > >  d_get64bitmembers;
 static pthread_mutex_t d_dynmetricslock = PTHREAD_MUTEX_INITIALIZER;
 static map<string, std::atomic<unsigned long>* > d_dynmetrics;
 
+static std::map<StatComponent, std::set<std::string>> s_blacklistedStats;
+
+bool isStatBlacklisted(StatComponent component, const string& name)
+{
+  return s_blacklistedStats[component].count(name) != 0;
+}
+
+void blacklistStat(StatComponent component, const string& name)
+{
+  s_blacklistedStats[component].insert(name);
+}
+
+void blacklistStats(StatComponent component, const string& stats)
+{
+  std::vector<std::string> blacklistedStats;
+  stringtok(blacklistedStats, stats, ", ");
+  auto& map = s_blacklistedStats[component];
+  for (const auto &st : blacklistedStats) {
+    map.insert(st);
+  }
+}
+
 static void addGetStat(const string& name, const uint32_t* place)
 {
   d_get32bitpointers[name]=place;
@@ -79,8 +100,6 @@ static optional<uint64_t> get(const string& name)
 
   if(d_get32bitpointers.count(name))
     return *d_get32bitpointers.find(name)->second;
-  if(d_get64bitpointers.count(name))
-    return *d_get64bitpointers.find(name)->second;
   if(d_getatomics.count(name))
     return d_getatomics.find(name)->second->load();
   if(d_get64bitmembers.count(name))
@@ -99,37 +118,44 @@ optional<uint64_t> getStatByName(const std::string& name)
   return get(name);
 }
 
-map<string,string> getAllStatsMap()
+map<string,string> getAllStatsMap(StatComponent component)
 {
   map<string,string> ret;
-  
+  const auto& blacklistMap = s_blacklistedStats.at(component);
+
   for(const auto& the32bits :  d_get32bitpointers) {
-    ret.insert(make_pair(the32bits.first, std::to_string(*the32bits.second)));
-  }
-  for(const auto& the64bits :  d_get64bitpointers) {
-    ret.insert(make_pair(the64bits.first, std::to_string(*the64bits.second)));
+    if (blacklistMap.count(the32bits.first) == 0) {
+      ret.insert(make_pair(the32bits.first, std::to_string(*the32bits.second)));
+    }
   }
   for(const auto& atomic :  d_getatomics) {
-    ret.insert(make_pair(atomic.first, std::to_string(atomic.second->load())));
+    if (blacklistMap.count(atomic.first) == 0) {
+      ret.insert(make_pair(atomic.first, std::to_string(atomic.second->load())));
+    }
   }
 
   for(const auto& the64bitmembers :  d_get64bitmembers) {
-    if(the64bitmembers.first == "cache-bytes" || the64bitmembers.first=="packetcache-bytes")
-      continue; // too slow for 'get-all'
-    if(the64bitmembers.first == "special-memory-usage")
-      continue; // too slow for 'get-all'
-    ret.insert(make_pair(the64bitmembers.first, std::to_string(the64bitmembers.second())));
+    if (blacklistMap.count(the64bitmembers.first) == 0) {
+      ret.insert(make_pair(the64bitmembers.first, std::to_string(the64bitmembers.second())));
+    }
   }
-  Lock l(&d_dynmetricslock);
-  for(const auto& a : d_dynmetrics)
-    ret.insert({a.first, std::to_string(*a.second)});
+
+  {
+    Lock l(&d_dynmetricslock);
+    for(const auto& a : d_dynmetrics) {
+      if (blacklistMap.count(a.first) == 0) {
+        ret.insert({a.first, std::to_string(*a.second)});
+      }
+    }
+  }
+
   return ret;
 }
 
-string getAllStats()
+static string getAllStats()
 {
   typedef map<string, string> varmap_t;
-  varmap_t varmap = getAllStatsMap();
+  varmap_t varmap = getAllStatsMap(StatComponent::RecControl);
   string ret;
   for(varmap_t::value_type& tup :  varmap) {
     ret += tup.first + "\t" + tup.second +"\n";
@@ -138,7 +164,7 @@ string getAllStats()
 }
 
 template<typename T>
-string doGet(T begin, T end)
+static string doGet(T begin, T end)
 {
   string ret;
 
@@ -153,7 +179,7 @@ string doGet(T begin, T end)
 }
 
 template<typename T>
-string doGetParameter(T begin, T end)
+string static doGetParameter(T begin, T end)
 {
   string ret;
   string parm;
@@ -206,7 +232,7 @@ static uint64_t* pleaseDumpThrottleMap(int fd)
 }
 
 template<typename T>
-string doDumpNSSpeeds(T begin, T end)
+static string doDumpNSSpeeds(T begin, T end)
 {
   T i=begin;
   string fname;
@@ -237,7 +263,7 @@ string doDumpNSSpeeds(T begin, T end)
 }
 
 template<typename T>
-string doDumpCache(T begin, T end)
+static string doDumpCache(T begin, T end)
 {
   T i=begin;
   string fname;
@@ -259,7 +285,7 @@ string doDumpCache(T begin, T end)
 }
 
 template<typename T>
-string doDumpEDNSStatus(T begin, T end)
+static string doDumpEDNSStatus(T begin, T end)
 {
   T i=begin;
   string fname;
@@ -281,7 +307,7 @@ string doDumpEDNSStatus(T begin, T end)
 }
 
 template<typename T>
-string doDumpRPZ(T begin, T end)
+static string doDumpRPZ(T begin, T end)
 {
   T i=begin;
 
@@ -320,7 +346,7 @@ string doDumpRPZ(T begin, T end)
 }
 
 template<typename T>
-string doDumpThrottleMap(T begin, T end)
+static string doDumpThrottleMap(T begin, T end)
 {
   T i=begin;
   string fname;
@@ -360,7 +386,7 @@ uint64_t* pleaseWipeAndCountNegCache(const DNSName& canon, bool subtree)
 
 
 template<typename T>
-string doWipeCache(T begin, T end)
+static string doWipeCache(T begin, T end)
 {
   vector<pair<DNSName, bool> > toWipe;
   for(T i=begin; i != end; ++i) {
@@ -391,7 +417,7 @@ string doWipeCache(T begin, T end)
 }
 
 template<typename T>
-string doSetCarbonServer(T begin, T end)
+static string doSetCarbonServer(T begin, T end)
 {
   Lock l(&g_carbon_config_lock);
   if(begin==end) {
@@ -424,7 +450,7 @@ string doSetCarbonServer(T begin, T end)
 }
 
 template<typename T>
-string doSetDnssecLogBogus(T begin, T end)
+static string doSetDnssecLogBogus(T begin, T end)
 {
   if(checkDNSSECDisabled())
     return "DNSSEC is disabled in the configuration, not changing the Bogus logging setting\n";
@@ -454,7 +480,7 @@ string doSetDnssecLogBogus(T begin, T end)
 }
 
 template<typename T>
-string doAddNTA(T begin, T end)
+static string doAddNTA(T begin, T end)
 {
   if(checkDNSSECDisabled())
     return "DNSSEC is disabled in the configuration, not adding a Negative Trust Anchor\n";
@@ -492,7 +518,7 @@ string doAddNTA(T begin, T end)
 }
 
 template<typename T>
-string doClearNTA(T begin, T end)
+static string doClearNTA(T begin, T end)
 {
   if(checkDNSSECDisabled())
     return "DNSSEC is disabled in the configuration, not removing a Negative Trust Anchor\n";
@@ -558,7 +584,7 @@ static string getNTAs()
 }
 
 template<typename T>
-string doAddTA(T begin, T end)
+static string doAddTA(T begin, T end)
 {
   if(checkDNSSECDisabled())
     return "DNSSEC is disabled in the configuration, not adding a Trust Anchor\n";
@@ -603,7 +629,7 @@ string doAddTA(T begin, T end)
 }
 
 template<typename T>
-string doClearTA(T begin, T end)
+static string doClearTA(T begin, T end)
 {
   if(checkDNSSECDisabled())
     return "DNSSEC is disabled in the configuration, not removing a Trust Anchor\n";
@@ -666,30 +692,59 @@ static string getTAs()
 }
 
 template<typename T>
-string setMinimumTTL(T begin, T end)
+static string setMinimumTTL(T begin, T end)
 {
-  if(end-begin != 1) 
+  if(end-begin != 1)
     return "Need to supply new minimum TTL number\n";
-  SyncRes::s_minimumTTL = pdns_stou(*begin);
-  return "New minimum TTL: " + std::to_string(SyncRes::s_minimumTTL) + "\n";
+  try {
+    SyncRes::s_minimumTTL = pdns_stou(*begin);
+    return "New minimum TTL: " + std::to_string(SyncRes::s_minimumTTL) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new minimum TTL number: " + std::string(e.what()) + "\n";
+  }
+}
+
+template<typename T>
+static string setMinimumECSTTL(T begin, T end)
+{
+  if(end-begin != 1)
+    return "Need to supply new ECS minimum TTL number\n";
+  try {
+    SyncRes::s_minimumECSTTL = pdns_stou(*begin);
+    return "New minimum ECS TTL: " + std::to_string(SyncRes::s_minimumECSTTL) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new ECS minimum TTL number: " + std::string(e.what()) + "\n";
+  }
 }
 
 template<typename T>
-string setMaxCacheEntries(T begin, T end)
+static string setMaxCacheEntries(T begin, T end)
 {
   if(end-begin != 1) 
     return "Need to supply new cache size\n";
-  g_maxCacheEntries = pdns_stou(*begin);
-  return "New max cache entries: " + std::to_string(g_maxCacheEntries) + "\n";
+  try {
+    g_maxCacheEntries = pdns_stou(*begin);
+    return "New max cache entries: " + std::to_string(g_maxCacheEntries) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new cache size: " + std::string(e.what()) + "\n";
+  }
 }
 
 template<typename T>
-string setMaxPacketCacheEntries(T begin, T end)
+static string setMaxPacketCacheEntries(T begin, T end)
 {
   if(end-begin != 1) 
     return "Need to supply new packet cache size\n";
-  g_maxPacketCacheEntries = pdns_stou(*begin);
-  return "New max packetcache entries: " + std::to_string(g_maxPacketCacheEntries) + "\n";
+  try {
+    g_maxPacketCacheEntries = pdns_stou(*begin);
+    return "New max packetcache entries: " + std::to_string(g_maxPacketCacheEntries) + "\n";
+  }
+  catch (const std::exception& e) {
+    return "Error parsing the new packet cache size: " + std::string(e.what()) + "\n";
+  }
 }
 
 
@@ -707,9 +762,51 @@ static uint64_t getUserTimeMsec()
   return (ru.ru_utime.tv_sec*1000ULL + ru.ru_utime.tv_usec/1000);
 }
 
+/* This is a pretty weird set of functions. To get per-thread cpu usage numbers,
+   we have to ask a thread over a pipe. We could do so surgically, so if you want to know about
+   thread 3, we pick pipe 3, but we lack that infrastructure.
+
+   We can however ask "execute this function on all threads and add up the results".
+   This is what the first function does using a custom object ThreadTimes, which if you add
+   to each other keeps filling the first one with CPU usage numbers
+*/
+
+static ThreadTimes* pleaseGetThreadCPUMsec()
+{
+  uint64_t ret=0;
+#ifdef RUSAGE_THREAD
+  struct rusage ru;
+  getrusage(RUSAGE_THREAD, &ru);
+  ret = (ru.ru_utime.tv_sec*1000ULL + ru.ru_utime.tv_usec/1000);
+  ret += (ru.ru_stime.tv_sec*1000ULL + ru.ru_stime.tv_usec/1000);
+#endif
+  return new ThreadTimes{ret};
+}
+
+/* Next up, when you want msec data for a specific thread, we check
+   if we recently executed pleaseGetThreadCPUMsec. If we didn't we do so
+   now and consult all threads.
+
+   We then answer you from the (re)fresh(ed) ThreadTimes.
+*/
+static uint64_t doGetThreadCPUMsec(int n)
+{
+  static std::mutex s_mut;
+  static time_t last = 0;
+  static ThreadTimes tt;
+
+  std::lock_guard<std::mutex> l(s_mut);
+  if(last != time(nullptr)) {
+   tt = broadcastAccFunction<ThreadTimes>(pleaseGetThreadCPUMsec);
+   last = time(nullptr);
+  }
+
+  return tt.times.at(n);
+}
+
 static uint64_t calculateUptime()
 {
-  return time(0) - g_stats.startupTime;
+  return time(nullptr) - g_stats.startupTime;
 }
 
 static string* pleaseGetCurrentQueries()
@@ -762,17 +859,18 @@ uint64_t* pleaseGetNegCacheSize()
   return new uint64_t(tmp);
 }
 
-uint64_t getNegCacheSize()
+static uint64_t getNegCacheSize()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetNegCacheSize);
 }
 
-uint64_t* pleaseGetFailedHostsSize()
+static uint64_t* pleaseGetFailedHostsSize()
 {
   uint64_t tmp=(SyncRes::getThrottledServersSize());
   return new uint64_t(tmp);
 }
-uint64_t getFailedHostsSize()
+
+static uint64_t getFailedHostsSize()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetFailedHostsSize);
 }
@@ -782,7 +880,7 @@ uint64_t* pleaseGetNsSpeedsSize()
   return new uint64_t(SyncRes::getNSSpeedsSize());
 }
 
-uint64_t getNsSpeedsSize()
+static uint64_t getNsSpeedsSize()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetNsSpeedsSize);
 }
@@ -802,24 +900,22 @@ uint64_t* pleaseGetCacheSize()
   return new uint64_t(t_RC ? t_RC->size() : 0);
 }
 
-uint64_t* pleaseGetCacheBytes()
+static uint64_t* pleaseGetCacheBytes()
 {
   return new uint64_t(t_RC ? t_RC->bytes() : 0);
 }
 
-
-uint64_t doGetCacheSize()
+static uint64_t doGetCacheSize()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetCacheSize);
 }
 
-uint64_t doGetAvgLatencyUsec()
+static uint64_t doGetAvgLatencyUsec()
 {
   return (uint64_t) g_stats.avgLatencyUsec;
 }
 
-
-uint64_t doGetCacheBytes()
+static uint64_t doGetCacheBytes()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetCacheBytes);
 }
@@ -829,7 +925,7 @@ uint64_t* pleaseGetCacheHits()
   return new uint64_t(t_RC ? t_RC->cacheHits : 0);
 }
 
-uint64_t doGetCacheHits()
+static uint64_t doGetCacheHits()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetCacheHits);
 }
@@ -839,55 +935,52 @@ uint64_t* pleaseGetCacheMisses()
   return new uint64_t(t_RC ? t_RC->cacheMisses : 0);
 }
 
-uint64_t doGetCacheMisses()
+static uint64_t doGetCacheMisses()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetCacheMisses);
 }
 
-
 uint64_t* pleaseGetPacketCacheSize()
 {
   return new uint64_t(t_packetCache ? t_packetCache->size() : 0);
 }
 
-uint64_t* pleaseGetPacketCacheBytes()
+static uint64_t* pleaseGetPacketCacheBytes()
 {
   return new uint64_t(t_packetCache ? t_packetCache->bytes() : 0);
 }
 
-
-uint64_t doGetPacketCacheSize()
+static uint64_t doGetPacketCacheSize()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetPacketCacheSize);
 }
 
-uint64_t doGetPacketCacheBytes()
+static uint64_t doGetPacketCacheBytes()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetPacketCacheBytes);
 }
 
-
 uint64_t* pleaseGetPacketCacheHits()
 {
   return new uint64_t(t_packetCache ? t_packetCache->d_hits : 0);
 }
 
-uint64_t doGetPacketCacheHits()
+static uint64_t doGetPacketCacheHits()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetPacketCacheHits);
 }
 
-uint64_t* pleaseGetPacketCacheMisses()
+static uint64_t* pleaseGetPacketCacheMisses()
 {
   return new uint64_t(t_packetCache ? t_packetCache->d_misses : 0);
 }
 
-uint64_t doGetPacketCacheMisses()
+static uint64_t doGetPacketCacheMisses()
 {
   return broadcastAccFunction<uint64_t>(pleaseGetPacketCacheMisses);
 }
 
-uint64_t doGetMallocated()
+static uint64_t doGetMallocated()
 {
   // this turned out to be broken
 /*  struct mallinfo mi = mallinfo();
@@ -979,13 +1072,13 @@ void registerAllStats()
   addGetStat("empty-queries", &g_stats.emptyQueriesCount);
   addGetStat("max-mthread-stack", &g_stats.maxMThreadStackUsage);
   
-  addGetStat("negcache-entries", boost::bind(getNegCacheSize));
-  addGetStat("throttle-entries", boost::bind(getThrottleSize)); 
+  addGetStat("negcache-entries", getNegCacheSize);
+  addGetStat("throttle-entries", getThrottleSize);
 
-  addGetStat("nsspeeds-entries", boost::bind(getNsSpeedsSize));
-  addGetStat("failed-host-entries", boost::bind(getFailedHostsSize));
+  addGetStat("nsspeeds-entries", getNsSpeedsSize);
+  addGetStat("failed-host-entries", getFailedHostsSize);
 
-  addGetStat("concurrent-queries", boost::bind(getConcurrentQueries)); 
+  addGetStat("concurrent-queries", getConcurrentQueries);
   addGetStat("security-status", &g_security_status);
   addGetStat("outgoing-timeouts", &SyncRes::s_outgoingtimeouts);
   addGetStat("outgoing4-timeouts", &SyncRes::s_outgoing4timeouts);
@@ -1031,6 +1124,9 @@ void registerAllStats()
   addGetStat("user-msec", getUserTimeMsec);
   addGetStat("sys-msec", getSysTimeMsec);
 
+  for(unsigned int n=0; n < g_numThreads; ++n)
+    addGetStat("cpu-msec-thread-"+std::to_string(n), boost::bind(&doGetThreadCPUMsec, n));
+
 #ifdef MALLOC_TRACE
   addGetStat("memory-allocs", boost::bind(&MallocTracer::getAllocs, g_mtracer, string()));
   addGetStat("memory-alloc-flux", boost::bind(&MallocTracer::getAllocFlux, g_mtracer, string()));
@@ -1050,6 +1146,19 @@ void registerAllStats()
   addGetStat("policy-result-nodata", &g_stats.policyResults[DNSFilterEngine::PolicyKind::NODATA]);
   addGetStat("policy-result-truncate", &g_stats.policyResults[DNSFilterEngine::PolicyKind::Truncate]);
   addGetStat("policy-result-custom", &g_stats.policyResults[DNSFilterEngine::PolicyKind::Custom]);
+
+  addGetStat("rebalanced-queries", &g_stats.rebalancedQueries);
+
+  /* make sure that the ECS stats are properly initialized */
+  SyncRes::clearECSStats();
+  for (size_t idx = 0; idx < SyncRes::s_ecsResponsesBySubnetSize4.size(); idx++) {
+    const std::string name = "ecs-v4-response-bits-" + std::to_string(idx + 1);
+    addGetStat(name, &(SyncRes::s_ecsResponsesBySubnetSize4.at(idx)));
+  }
+  for (size_t idx = 0; idx < SyncRes::s_ecsResponsesBySubnetSize6.size(); idx++) {
+    const std::string name = "ecs-v6-response-bits-" + std::to_string(idx + 1);
+    addGetStat(name, &(SyncRes::s_ecsResponsesBySubnetSize6.at(idx)));
+  }
 }
 
 static void doExitGeneric(bool nicely)
@@ -1329,6 +1438,7 @@ string RecursorControlParser::getAnswer(const string& question, RecursorControlP
 "reload-lua-script [filename]     (re)load Lua script\n"
 "reload-lua-config [filename]     (re)load Lua configuration file\n"
 "reload-zones                     reload all auth and forward zones\n"
+"set-ecs-minimum-ttl value        set ecs-minimum-ttl-override\n"
 "set-max-cache-entries value      set new maximum cache size\n"
 "set-max-packetcache-entries val  set new maximum packet cache size\n"      
 "set-minimum-ttl value            set minimum-ttl-override\n"
@@ -1499,6 +1609,10 @@ string RecursorControlParser::getAnswer(const string& question, RecursorControlP
     return reloadAuthAndForwards();
   }
 
+  if(cmd=="set-ecs-minimum-ttl") {
+    return setMinimumECSTTL(begin, end);
+  }
+
   if(cmd=="set-max-cache-entries") {
     return setMaxCacheEntries(begin, end);
   }
index 92c8ae5a27ed0b0eb3ae11fd64c662c64771f1e7..2d7052b1b1c15e1c9413189603e5f71335a162e1 100644 (file)
@@ -168,6 +168,7 @@ pdns_recursor_SOURCES = \
        ueberbackend.hh \
        unix_utility.cc \
        utility.hh \
+       uuid-utils.hh uuid-utils.cc \
        validate.cc validate.hh validate-recursor.cc validate-recursor.hh \
        version.cc version.hh \
        webserver.cc webserver.hh \
@@ -235,6 +236,7 @@ testrunner_SOURCES = \
        nsecrecords.cc \
        pdnsexception.hh \
        opensslsigners.cc opensslsigners.hh \
+       pollmplexer.cc \
        protobuf.cc protobuf.hh \
        qtype.cc qtype.hh \
        rcpgenerator.cc \
@@ -263,6 +265,7 @@ testrunner_SOURCES = \
        test-ixfr_cc.cc \
        test-misc_hh.cc \
        test-mtasker.cc \
+       test-mplexer.cc \
        test-negcache_cc.cc \
        test-packetcache_hh.cc \
        test-rcpgenerator_cc.cc \
@@ -335,16 +338,21 @@ endif
 
 if HAVE_FREEBSD
 pdns_recursor_SOURCES += kqueuemplexer.cc
+testrunner_SOURCES += kqueuemplexer.cc
 endif
 
 if HAVE_LINUX
 pdns_recursor_SOURCES += epollmplexer.cc
+testrunner_SOURCES += epollmplexer.cc
 endif
 
 if HAVE_SOLARIS
 pdns_recursor_SOURCES += \
        devpollmplexer.cc \
        portsmplexer.cc
+testrunner_SOURCES += \
+       devpollmplexer.cc \
+       portsmplexer.cc
 endif
 
 if HAVE_PROTOBUF
@@ -362,10 +370,6 @@ testrunner_LDADD += $(PROTOBUF_LIBS)
 testrunner$(OBJEXT): dnsmessage.pb.cc
 
 endif
-
-pdns_recursor_SOURCES += \
-       uuid-utils.hh uuid-utils.cc
-
 endif
 
 rec_control_SOURCES = \
index cc3368cce7dae6fc707066e7c7b9b24e529e843b..762901a0cb8ddd67229c5734c6f2b0c512684feb 100644 (file)
@@ -14,8 +14,8 @@ Starting with version 4.0.0, the PowerDNS recursor uses autotools and compiling
 make
 ```
 
-As for dependencies, Boost (http://boost.org/) and OpenSSL (https://openssl.org/)
-are required.
+As for dependencies, Boost (http://boost.org/), OpenSSL (https://openssl.org/),
+and Lua (https://www.lua.org/) are required.
 
 On most modern UNIX distributions, you can simply install 'boost' or
 'boost-dev' or 'boost-devel'. Otherwise, just download boost, and point the
index 2af4621d08d6daefe3406205d5f5e20e586246ba..6b9fac22e1bd65cfa66c5c1d0431a759301e194a 100644 (file)
@@ -817,6 +817,14 @@ specialMemoryUsage OBJECT-TYPE
         "Memory usage (more precise bbut expensive to retrieve)"
     ::= { stats 98 }
 
+rebalancedQueries OBJECT-TYPE
+    SYNTAX Counter64
+    MAX-ACCESS read-only
+    STATUS current
+    DESCRIPTION
+        "Number of queries re-distributed because the first selected worker thread was above the target load"
+    ::= { stats 99 }
+
 ---
 --- Traps / Notifications
 ---
@@ -957,6 +965,8 @@ recGroup OBJECT-GROUP
         dnssecAuthenticDataQueries,
         dnssecCheckDisabledQueries,
         variableResponses,
+        specialMemoryUsage,
+        rebalancedQueries,
         trapReason
     }
     STATUS current
index ce066fb7a31c9e139f60106c3ae3ec472e5b34b8..fbad126428bc35290896463ca950f6146df323de 100644 (file)
@@ -10,8 +10,9 @@ AC_CONFIG_MACRO_DIR([m4])
 AC_CONFIG_HEADERS([config.h])
 
 AC_CANONICAL_HOST
-: ${CFLAGS="-Wall -g -O2"}
-: ${CXXFLAGS="-Wall -g -O2"}
+# Add some default CFLAGS and CXXFLAGS, can be appended to using the environment variables
+CFLAGS="-Wall -g -O2 $CFLAGS"
+CXXFLAGS="-Wall -g -O2 $CXXFLAGS"
 
 AC_SUBST([pdns_configure_args],["$ac_configure_args"])
 AC_DEFINE_UNQUOTED([PDNS_CONFIG_ARGS],
@@ -81,16 +82,9 @@ AC_DEFUN([PDNS_SELECT_CONTEXT_IMPL], [
 
 PDNS_CHECK_CLOCK_GETTIME
 
-boost_required_version=1.35
-
 PDNS_WITH_PROTOBUF
-AS_IF([test "x$PROTOBUF_LIBS" != "x" -a x"$PROTOC" != "x"],
-  # The protobuf code needs boost::uuid, which is available from 1.42 onward
-  [AC_MSG_WARN([Bumping minimal Boost requirement to 1.42. To keep the requirement at 1.35, disable protobuf support])
-  boost_required_version=1.42]
-)
 
-BOOST_REQUIRE([$boost_required_version])
+BOOST_REQUIRE([1.42])
 
 # Check against flat_set header that requires boost >= 1.48
 BOOST_FIND_HEADER([boost/container/flat_set.hpp], [AC_MSG_NOTICE([boost::container::flat_set not available, will fallback to std::set])])
index a7843705a326ecd2e0b613e5c31dc786e685f008..6a13b76b50abff67f11e851b5625086ce38cb320 100644 (file)
@@ -1,11 +1,42 @@
 Changelogs for 4.1.x
 ====================
 
+.. changelog::
+  :version: 4.1.12
+  :released: 2nd of April 2019
+
+  .. change::
+    :tags: Bug Fixes, Internals
+    :pullreq: 7495
+    :tickets: 7494
+
+    Correctly interpret an empty AXFR response to an IXFR query.
+
+  .. change::
+    :tags: Improvements, Internals
+    :pullreq: 7647
+
+    Provide CPU usage statistics per thread (worker & distributor).
+
+  .. change::
+    :tags: Improvements, Internals, Performance
+    :pullreq: 7634
+    :tickets: 7507
+
+    Use a bounded load-balancing algo to distribute queries.
+
+  .. change::
+    :tags: Improvements, Internals
+    :pullreq: 7651
+    :tickets: 7631, 7572
+
+    Implement a configurable ECS cache limit so responses with an ECS scope more specific than a certain threshold and a TTL smaller than a specific threshold are not inserted into the records cache at all.
+
 .. changelog::
   :version: 4.1.11
   :released: 1st of February 2019
 
-  Since Spectre/Meltdown, system calls have become more expensive.  This made exporting a very high number of protobuf messages costly, which is addressed in this release by reducing the number of sycalls per message.
+  Since Spectre/Meltdown, system calls have become more expensive.  This made exporting a very high number of protobuf messages costly, which is addressed in this release by reducing the number of syscalls per message.
 
   .. change::
     :tags: Improvements
index d6b8dfe804981c2d7b531979fce9f898830cb63f..8fa4b018cdaba19ce3a1752361ba157cc67c16ed 100644 (file)
@@ -196,6 +196,9 @@ set-dnssec-log-bogus *SETTING*
     DNSSEC validation failures and to 'no' or 'off' to disable logging these
     failures.
 
+set-ecs-minimum-ttl *NUM*
+    Set ecs-minimum-ttl-override to *NUM*.
+
 set-max-cache-entries *NUM*
     Change the maximum number of entries in the DNS cache.  If reduced, the
     cache size will start shrinking to this number as part of the normal
index 51e2ab397d94ef2e649417874d8678953bbe17ef..c04ebbc01f0ce91ea9975b364c5feb2b03d3a6b0 100644 (file)
@@ -185,6 +185,10 @@ concurrent-queries
 ^^^^^^^^^^^^^^^^^^
 shows the number of MThreads currently   running
 
+cpu-msec-thread-n
+^^^^^^^^^^^^^^^^^
+shows the number of milliseconds spent in thread n. Available since 4.1.12.
+
 dlg-only-drops
 ^^^^^^^^^^^^^^
 number of records dropped because of :ref:`setting-delegation-only` setting
@@ -241,6 +245,18 @@ ecs-responses
 ^^^^^^^^^^^^^
 number of responses received from authoritative servers with an EDNS Client Subnet option we used (since 4.1)
 
+ecs-v4-response-bits-*
+^^^^^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.2.0
+
+number of responses received from authoritative servers with an IPv4 EDNS Client Subnet option we used, of this subnet size (1 to 32).
+
+ecs-v6-response-bits-*
+^^^^^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.2.0
+
+number of responses received from authoritative servers with an IPv6 EDNS Client Subnet option we used, of this subnet size (1 to 128).
+
 edns-ping-matches
 ^^^^^^^^^^^^^^^^^
 number of servers that sent a valid EDNS PING   response
@@ -387,6 +403,12 @@ questions
 ^^^^^^^^^
 counts all end-user initiated queries with the RD bit   set
 
+rebalanced-queries
+^^^^^^^^^^^^^^^^^^
+.. versionadded:: 4.1.12
+
+number of queries balanced to a different worker thread because the first selected one was above the target load configured with 'distribution-load-factor'
+
 resource-limits
 ^^^^^^^^^^^^^^^
 counts number of queries that could not be   performed because of resource limits
index 27e722fdb5896c8ad02a7f5ee684976cc7508c65..78286bb97be37e2d03549f45c0203d12970d41ff 100644 (file)
@@ -278,6 +278,25 @@ Do not log to syslog, only to stdout.
 Use this setting when running inside a supervisor that handles logging (like systemd).
 **Note**: do not use this setting in combination with `daemon`_ as all logging will disappear.
 
+.. _setting-distribution-load-factor:
+
+``distribution-load-factor``
+----------------------------
+.. versionadded:: 4.1.12
+
+-  Double
+-  Default: 0.0
+
+If `pdns-distributes-queries`_ is set and this setting is set to another value
+than 0, the distributor thread will use a bounded load-balancing algorithm while
+distributing queries to worker threads, making sure that no thread is assigned
+more queries than distribution-load-factor times the average number of queries
+currently processed by all the workers.
+For example, with a value of 1.25, no server should get more than 125 % of the
+average load. This helps making sure that all the workers have roughly the same
+share of queries, even if the incoming traffic is very skewed, with a larger
+number of requests asking for the same qname.
+
 .. _setting-distributor-threads:
 
 ``distributor-threads``
@@ -377,6 +396,18 @@ This defaults to not using the requestor address inside RFC1918 and similar "pri
 
 Number of bits of client IPv4 address to pass when sending EDNS Client Subnet address information.
 
+.. _setting-ecs-ipv4-cache-bits:
+
+``ecs-ipv4-cache-bits``
+-----------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 24
+
+Maximum number of bits of client IPv4 address used by the authoritative server (as indicated by the EDNS Client Subnet scope in the answer) for an answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-cache-limit-ttl``.
+That is, only if both the limits apply, the record will not be cached.
+
 .. _setting-ecs-ipv6-bits:
 
 ``ecs-ipv6-bits``
@@ -388,6 +419,41 @@ Number of bits of client IPv4 address to pass when sending EDNS Client Subnet ad
 
 Number of bits of client IPv6 address to pass when sending EDNS Client Subnet address information.
 
+.. _setting-ecs-ipv6-cache-bits:
+
+``ecs-ipv6-cache-bits``
+-----------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 56
+
+Maximum number of bits of client IPv6 address used by the authoritative server (as indicated by the EDNS Client Subnet scope in the answer) for an answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-cache-limit-ttl``.
+That is, only if both the limits apply, the record will not be cached.
+
+.. _setting-ecs-minimum-ttl-override:
+
+``ecs-minimum-ttl-override``
+----------------------------
+-  Integer
+-  Default: 0 (disabled)
+
+This setting artificially raises the TTLs of records in the ANSWER section of ECS-specific answers to be at least this long.
+While this is a gross hack, and violates RFCs, under conditions of DoS, it may enable you to continue serving your customers.
+Can be set at runtime using ``rec_control set-ecs-minimum-ttl 3600``.
+
+.. _setting-ecs-cache-limit-ttl:
+
+``ecs-cache-limit-ttl``
+-----------------------
+.. versionadded:: 4.1.12
+
+-  Integer
+-  Default: 0 (disabled)
+
+The minimum TTL for an ECS-specific answer to be inserted into the query cache. This condition applies in conjunction with ``ecs-ipv4-cache-bits`` or ``ecs-ipv6-cache-bits``.
+That is, only if both the limits apply, the record will not be cached.
+
 .. _setting-ecs-scope-zero-address:
 
 ``ecs-scope-zero-address``
@@ -1271,6 +1337,41 @@ Size of the stack per thread.
 Interval between logging statistical summary on recursor performance.
 Use 0 to disable.
 
+.. _setting-stats-api-blacklist:
+
+``stats-api-blacklist``
+-----------------------
+.. versionadded:: 4.2.0
+
+-  String
+-  Default: "cache-bytes, packetcache-bytes, ecs-v4-response-bits-*, ecs-v6-response-bits-*"
+
+A list of comma-separated statistic names, that are disabled when retrieving the complete list of statistics via the API for performance reasons.
+These statistics can still be retrieved individually by specifically asking for it.
+
+.. _setting-stats-carbon-blacklist:
+
+``stats-carbon-blacklist``
+--------------------------
+.. versionadded:: 4.2.0
+
+-  String
+-  Default: "cache-bytes, packetcache-bytes, ecs-v4-response-bits-*, ecs-v6-response-bits-*"
+
+A list of comma-separated statistic names, that are prevented from being exported via carbon for performance reasons.
+
+.. _setting-stats-rec-control-blacklist:
+
+``stats-rec-control-blacklist``
+-------------------------------
+.. versionadded:: 4.2.0
+
+-  String
+-  Default: "cache-bytes, packetcache-bytes, ecs-v4-response-bits-*, ecs-v6-response-bits-*"
+
+A list of comma-separated statistic names, that are disabled when retrieving the complete list of statistics via `rec_control get-all`, for performance reasons.
+These statistics can still be retrieved individually.
+
 .. _setting-stats-ringbuffer-entries:
 
 ``stats-ringbuffer-entries``
@@ -1281,6 +1382,17 @@ Use 0 to disable.
 Number of entries in the remotes ringbuffer, which keeps statistics on who is querying your server.
 Can be read out using ``rec_control top-remotes``.
 
+.. _setting-stats-snmp-blacklist:
+
+``stats-snmp-blacklist``
+------------------------
+.. versionadded:: 4.2.0
+
+-  String
+-  Default: "cache-bytes, packetcache-bytes, ecs-v4-response-bits-*, ecs-v6-response-bits-*"
+
+A list of comma-separated statistic names, that are prevented from being exported via SNMP, for performance reasons.
+
 .. _setting-tcp-fast-open:
 
 ``tcp-fast-open``
@@ -1512,6 +1624,47 @@ IP address for the webserver to listen on.
 
 These subnets are allowed to access the webserver.
 
+.. _setting-webserver-loglevel:
+
+``webserver-loglevel``
+----------------------
+.. versionadded:: 4.2.0
+
+-  String, one of "none", "normal", "detailed"
+
+The amount of logging the webserver must do. "none" means no useful webserver information will be logged.
+When set to "normal", the webserver will log a line per request that should be familiar::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+When set to "detailed", all information about the request and response are logged::
+
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Request Details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-encoding: gzip, deflate
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   accept-language: en-US,en;q=0.5
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   connection: keep-alive
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   dnt: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   host: 127.0.0.1:8081
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   upgrade-insecure-requests: 1
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   user-agent: Mozilla/5.0 (X11; Linux x86_64; rv:64.0) Gecko/20100101 Firefox/64.0
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  No body
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e Response details:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Headers:
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Connection: close
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Length: 49
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Content-Type: text/html; charset=utf-8
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   Server: PowerDNS/0.0.15896.0.gaba8bab3ab
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e  Full body: 
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e   <!html><title>Not Found</title><h1>Not Found</h1>
+  [webserver] e235780e-a5cf-415e-9326-9d33383e739e 127.0.0.1:55376 "GET /api/v1/servers/localhost/bla HTTP/1.1" 404 196
+
+The value between the hooks is a UUID that is generated for each request. This can be used to find all lines related to a single request.
+
+.. note::
+  The webserver logs these line on the NOTICE level. The :ref:`settings-loglevel` seting must be 5 or higher for these lines to end up in the log.
+
 .. _setting-webserver-password:
 
 ``webserver-password``
diff --git a/pdns/recursordist/test-mplexer.cc b/pdns/recursordist/test-mplexer.cc
new file mode 120000 (symlink)
index 0000000..f406267
--- /dev/null
@@ -0,0 +1 @@
+../test-mplexer.cc
\ No newline at end of file
index 8736f32aee43dac397c1f50c1e6487d43c2a8dbb..4b5e6297225b103aa743da7232ea85819c7bd0f5 100644 (file)
@@ -130,8 +130,12 @@ static void init(bool debug=false)
   SyncRes::s_doIPv6 = true;
   SyncRes::s_ecsipv4limit = 24;
   SyncRes::s_ecsipv6limit = 56;
+  SyncRes::s_ecsipv4cachelimit = 24;
+  SyncRes::s_ecsipv6cachelimit = 56;
+  SyncRes::s_ecscachelimitttl = 0;
   SyncRes::s_rootNXTrust = true;
   SyncRes::s_minimumTTL = 0;
+  SyncRes::s_minimumECSTTL = 0;
   SyncRes::s_serverID = "PowerDNS Unit Tests Server ID";
   SyncRes::clearEDNSLocalSubnets();
   SyncRes::addEDNSLocalSubnet("0.0.0.0/0");
@@ -151,6 +155,8 @@ static void init(bool debug=false)
   SyncRes::clearFailedServers();
   BOOST_CHECK_EQUAL(SyncRes::getFailedServersSize(), 0);
 
+  SyncRes::clearECSStats();
+
   auto luaconfsCopy = g_luaconfs.getCopy();
   luaconfsCopy.dfe.clear();
   luaconfsCopy.dsAnchors.clear();
@@ -1079,7 +1085,7 @@ BOOST_AUTO_TEST_CASE(test_glueless_referral) {
   BOOST_CHECK_EQUAL(ret[0].d_name, target);
 }
 
-BOOST_AUTO_TEST_CASE(test_edns_submask_by_domain) {
+BOOST_AUTO_TEST_CASE(test_edns_subnet_by_domain) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
 
@@ -1096,15 +1102,49 @@ BOOST_AUTO_TEST_CASE(test_edns_submask_by_domain) {
 
       BOOST_REQUIRE(srcmask);
       BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      if (isRootServer(ip)) {
+        setLWResult(res, 0, false, false, true);
+        addRecordToLW(res, domain, QType::NS, "a.gtld-servers.net.", DNSResourceRecord::AUTHORITY, 172800);
+        addRecordToLW(res, "a.gtld-servers.net.", QType::A, "192.0.2.1", DNSResourceRecord::ADDITIONAL, 3600);
+
+        /* this one did not use the ECS info */
+        srcmask = boost::none;
+
+        return 1;
+      } else if (ip == ComboAddress("192.0.2.1:53")) {
+
+        setLWResult(res, 0, true, false, false);
+        addRecordToLW(res, domain, QType::A, "192.0.2.2");
+
+        /* this one did, but only up to a precision of /16, not the full /24 */
+        srcmask = Netmask("192.0.0.0/16");
+
+        return 1;
+      }
+
       return 0;
     });
 
+  SyncRes::s_ecsqueries = 0;
+  SyncRes::s_ecsresponses = 0;
   vector<DNSRecord> ret;
   int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
-  BOOST_CHECK_EQUAL(res, RCode::ServFail);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_REQUIRE_EQUAL(ret.size(), 1);
+  BOOST_CHECK(ret[0].d_type == QType::A);
+  BOOST_CHECK_EQUAL(ret[0].d_name, target);
+  BOOST_CHECK_EQUAL(SyncRes::s_ecsqueries, 2);
+  BOOST_CHECK_EQUAL(SyncRes::s_ecsresponses, 1);
+  for (const auto& entry : SyncRes::s_ecsResponsesBySubnetSize4) {
+    BOOST_CHECK_EQUAL(entry.second, entry.first == 15 ? 1 : 0);
+  }
+  for (const auto& entry : SyncRes::s_ecsResponsesBySubnetSize6) {
+    BOOST_CHECK_EQUAL(entry.second, 0);
+  }
 }
 
-BOOST_AUTO_TEST_CASE(test_edns_submask_by_addr) {
+BOOST_AUTO_TEST_CASE(test_edns_subnet_by_addr) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
 
@@ -1139,12 +1179,22 @@ BOOST_AUTO_TEST_CASE(test_edns_submask_by_addr) {
       return 0;
     });
 
+  SyncRes::s_ecsqueries = 0;
+  SyncRes::s_ecsresponses = 0;
   vector<DNSRecord> ret;
   int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
   BOOST_CHECK_EQUAL(res, RCode::NoError);
   BOOST_REQUIRE_EQUAL(ret.size(), 1);
   BOOST_CHECK(ret[0].d_type == QType::A);
   BOOST_CHECK_EQUAL(ret[0].d_name, target);
+  BOOST_CHECK_EQUAL(SyncRes::s_ecsqueries, 1);
+  BOOST_CHECK_EQUAL(SyncRes::s_ecsresponses, 1);
+  for (const auto& entry : SyncRes::s_ecsResponsesBySubnetSize4) {
+    BOOST_CHECK_EQUAL(entry.second, 0);
+  }
+  for (const auto& entry : SyncRes::s_ecsResponsesBySubnetSize6) {
+    BOOST_CHECK_EQUAL(entry.second, entry.first == 55 ? 1 : 0);
+  }
 }
 
 BOOST_AUTO_TEST_CASE(test_ecs_use_requestor) {
@@ -2023,6 +2073,8 @@ BOOST_AUTO_TEST_CASE(test_skip_negcache_for_variable_response) {
         addRecordToLW(res, "powerdns.com.", QType::NS, "pdns-public-ns1.powerdns.com.", DNSResourceRecord::AUTHORITY, 172800);
         addRecordToLW(res, "pdns-public-ns1.powerdns.com.", QType::A, "192.0.2.1", DNSResourceRecord::ADDITIONAL, 3600);
 
+        srcmask = boost::none;
+
         return 1;
       } else if (ip == ComboAddress("192.0.2.1:53")) {
         if (domain == target) {
@@ -2052,6 +2104,204 @@ BOOST_AUTO_TEST_CASE(test_skip_negcache_for_variable_response) {
   BOOST_CHECK_EQUAL(SyncRes::getNegCacheSize(), 0);
 }
 
+BOOST_AUTO_TEST_CASE(test_ecs_cache_limit_allowed) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("www.powerdns.com.");
+
+  SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::s_ecsipv4cachelimit = 24;
+
+  sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 1);
+
+  /* should have been cached */
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_limit_no_ttl_limit_allowed) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("www.powerdns.com.");
+
+  SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::s_ecsipv4cachelimit = 16;
+
+  sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_CHECK_EQUAL(ret.size(), 1);
+
+  /* should have been cached because /24 is more specific than /16 but TTL limit is nof effective */
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_allowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 30;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have been cached */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_and_scope_allowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 100;
+    SyncRes::s_ecsipv4cachelimit = 24;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have been cached */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(test_ecs_cache_ttllimit_notallowed) {
+    std::unique_ptr<SyncRes> sr;
+    initSR(sr);
+
+    primeHints();
+
+    const DNSName target("www.powerdns.com.");
+
+    SyncRes::addEDNSDomain(DNSName("powerdns.com."));
+
+    EDNSSubnetOpts incomingECS;
+    incomingECS.source = Netmask("192.0.2.128/32");
+    sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+    SyncRes::s_ecscachelimitttl = 100;
+    SyncRes::s_ecsipv4cachelimit = 16;
+
+    sr->setAsyncCallback([target](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      setLWResult(res, 0, true, false, true);
+      addRecordToLW(res, target, QType::A, "192.0.2.1");
+
+      return 1;
+    });
+
+    const time_t now = sr->getNow().tv_sec;
+    vector<DNSRecord> ret;
+    int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+    BOOST_CHECK_EQUAL(res, RCode::NoError);
+    BOOST_CHECK_EQUAL(ret.size(), 1);
+
+    /* should have NOT been cached because TTL of 60 is too small and /24 is more specific than /16 */
+    const ComboAddress who("192.0.2.128");
+    vector<DNSRecord> cached;
+    BOOST_REQUIRE_LT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+    BOOST_REQUIRE_EQUAL(cached.size(), 0);
+}
+
+
 BOOST_AUTO_TEST_CASE(test_ns_speed) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
@@ -2300,6 +2550,75 @@ BOOST_AUTO_TEST_CASE(test_cache_min_max_ttl) {
   BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_maxcachettl);
 }
 
+BOOST_AUTO_TEST_CASE(test_cache_min_max_ecs_ttl) {
+  std::unique_ptr<SyncRes> sr;
+  initSR(sr);
+
+  primeHints();
+
+  const DNSName target("cacheecsttl.powerdns.com.");
+  const ComboAddress ns("192.0.2.1:53");
+
+  EDNSSubnetOpts incomingECS;
+  incomingECS.source = Netmask("192.0.2.128/32");
+  sr->setQuerySource(ComboAddress(), boost::optional<const EDNSSubnetOpts&>(incomingECS));
+  SyncRes::addEDNSDomain(target);
+
+  sr->setAsyncCallback([target,ns](const ComboAddress& ip, const DNSName& domain, int type, bool doTCP, bool sendRDQuery, int EDNS0Level, struct timeval* now, boost::optional<Netmask>& srcmask, boost::optional<const ResolveContext&> context, LWResult* res, bool* chained) {
+
+      BOOST_REQUIRE(srcmask);
+      BOOST_CHECK_EQUAL(srcmask->toString(), "192.0.2.0/24");
+
+      if (isRootServer(ip)) {
+
+        setLWResult(res, 0, false, false, true);
+        addRecordToLW(res, domain, QType::NS, "a.gtld-servers.net.", DNSResourceRecord::AUTHORITY, 172800);
+        addRecordToLW(res, "a.gtld-servers.net.", QType::A, ns.toString(), DNSResourceRecord::ADDITIONAL, 20);
+        srcmask = boost::none;
+
+        return 1;
+      } else if (ip == ns) {
+
+        setLWResult(res, 0, true, false, false);
+        addRecordToLW(res, domain, QType::A, "192.0.2.2", DNSResourceRecord::ANSWER, 10);
+
+        return 1;
+      }
+
+      return 0;
+    });
+
+  const time_t now = sr->getNow().tv_sec;
+  SyncRes::s_minimumTTL = 60;
+  SyncRes::s_minimumECSTTL = 120;
+  SyncRes::s_maxcachettl = 3600;
+
+  vector<DNSRecord> ret;
+  int res = sr->beginResolve(target, QType(QType::A), QClass::IN, ret);
+  BOOST_CHECK_EQUAL(res, RCode::NoError);
+  BOOST_REQUIRE_EQUAL(ret.size(), 1);
+  BOOST_CHECK_EQUAL(ret[0].d_ttl, SyncRes::s_minimumECSTTL);
+
+  const ComboAddress who("192.0.2.128");
+  vector<DNSRecord> cached;
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::A), true, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_EQUAL((cached[0].d_ttl - now), SyncRes::s_minimumECSTTL);
+
+  cached.clear();
+  BOOST_REQUIRE_GT(t_RC->get(now, target, QType(QType::NS), false, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_maxcachettl);
+
+  cached.clear();
+  BOOST_REQUIRE_GT(t_RC->get(now, DNSName("a.gtld-servers.net."), QType(QType::A), false, &cached, who), 0);
+  BOOST_REQUIRE_EQUAL(cached.size(), 1);
+  BOOST_REQUIRE_GT(cached[0].d_ttl, now);
+  BOOST_CHECK_LE((cached[0].d_ttl - now), SyncRes::s_minimumTTL);
+}
+
 BOOST_AUTO_TEST_CASE(test_cache_expired_ttl) {
   std::unique_ptr<SyncRes> sr;
   initSR(sr);
index 61b4c6b0087f43990b45f14690ec7e67da161365..ac4f75a4c82d38c7be22846046449842a5ed130a 100644 (file)
@@ -24,9 +24,9 @@ void ResponseStats::submitResponse(DNSPacket &p, bool udpOrTCP) {
 
   if(p.d.aa) {
     if (p.d.rcode==RCode::NXDomain)
-      S.ringAccount("nxdomain-queries",p.qdomain.toLogString()+"/"+p.qtype.getName());
-  } else if (p.isEmpty()) {
-    S.ringAccount("unauth-queries",p.qdomain.toLogString()+"/"+p.qtype.getName());
+      S.ringAccount("nxdomain-queries", p.qdomain, p.qtype);
+  } else if (p.d.rcode == RCode::Refused) {
+    S.ringAccount("unauth-queries", p.qdomain, p.qtype);
     S.ringAccount("remotes-unauth",p.d_remote);
   }
 
index 1a504ef30232bb6826d8606fbf7bba7075f36688..3f923cb1bd77391750e5a6736e259778566cb3f4 100644 (file)
@@ -791,6 +791,34 @@ struct NOPTest
 
 };
 
+struct StatRingDNSNameQTypeToStringTest
+{
+  explicit StatRingDNSNameQTypeToStringTest(const DNSName &name, const QType type) : d_name(name), d_type(type) {}
+
+  string getName() const { return "StatRing test with DNSName and QType to string"; }
+
+  void operator()() const {
+    S.ringAccount("testring", d_name.toLogString()+"/"+d_type.getName());
+  };
+
+  DNSName d_name;
+  QType d_type;
+};
+
+struct StatRingDNSNameQTypeTest
+{
+  explicit StatRingDNSNameQTypeTest(const DNSName &name, const QType type) : d_name(name), d_type(type) {}
+
+  string getName() const { return "StatRing test with DNSName and QType"; }
+
+  void operator()() const {
+    S.ringAccount("testringdnsname", d_name, d_type);
+  };
+
+  DNSName d_name;
+  QType d_type;
+};
+
 
 
 int main(int argc, char** argv)
@@ -876,6 +904,16 @@ try
   doRun(DNSNameParseTest());
   doRun(DNSNameRootTest());
 
+#ifndef RECURSOR
+  S.doRings();
+
+  S.declareRing("testring", "Just some ring where we'll account things");
+  doRun(StatRingDNSNameQTypeToStringTest(DNSName("example.com"), QType(1)));
+
+  S.declareDNSNameQTypeRing("testringdnsname", "Just some ring where we'll account things");
+  doRun(StatRingDNSNameQTypeTest(DNSName("example.com"), QType(1)));
+#endif
+
   cerr<<"Total runs: " << g_totalRuns<<endl;
 
 }
index 403611eff9b18b9204f7d50fbcde32eeb4e56bae..922315a5e21bc3aeeb2c1c4bd1c321b0d29030b1 100644 (file)
@@ -47,43 +47,42 @@ typedef int ProtocolType; //!< Supported protocol types
 //! Representation of a Socket and many of the Berkeley functions available
 class Socket : public boost::noncopyable
 {
-  Socket(int fd)
+  Socket(int fd): d_socket(fd)
   {
-    d_socket = fd;
-    d_buflen=4096;
-    d_buffer=new char[d_buflen];
   }
 
 public:
   //! Construct a socket of specified address family and socket type.
   Socket(int af, int st, ProtocolType pt=0)
   {
-    if((d_socket=socket(af,st, pt))<0)
+    if((d_socket=socket(af, st, pt))<0)
       throw NetworkError(strerror(errno));
     setCloseOnExec(d_socket);
+  }
 
-    d_buflen=4096;
-    d_buffer=new char[d_buflen];
+  Socket(Socket&& rhs): d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket)
+  {
+    rhs.d_socket = -1;
   }
 
   ~Socket()
   {
     try {
-      closesocket(d_socket);
+      if (d_socket != -1) {
+        closesocket(d_socket);
+      }
     }
     catch(const PDNSException& e) {
     }
-
-    delete[] d_buffer;
   }
 
   //! If the socket is capable of doing so, this function will wait for a connection
-  Socket *accept()
+  std::unique_ptr<Socket> accept()
   {
     struct sockaddr_in remote;
     socklen_t remlen=sizeof(remote);
     memset(&remote, 0, sizeof(remote));
-    int s=::accept(d_socket,(sockaddr *)&remote, &remlen);
+    int s=::accept(d_socket, reinterpret_cast<sockaddr *>(&remote), &remlen);
     if(s<0) {
       if(errno==EAGAIN)
         return nullptr;
@@ -91,21 +90,21 @@ public:
       throw NetworkError("Accepting a connection: "+string(strerror(errno)));
     }
 
-    return new Socket(s);
+    return std::unique_ptr<Socket>(new Socket(s));
   }
 
   //! Get remote address
   bool getRemote(ComboAddress &remote) {
     socklen_t remotelen=sizeof(remote);
-    return (getpeername(d_socket, (struct sockaddr *)&remote, &remotelen) >= 0);
+    return (getpeername(d_socket, reinterpret_cast<struct sockaddr *>(&remote), &remotelen) >= 0);
   }
 
   //! Check remote address against netmaskgroup ng
-  bool acl(NetmaskGroup &ng)
+  bool acl(const NetmaskGroup &ng)
   {
     ComboAddress remote;
     if (getRemote(remote))
-      return ng.match((ComboAddress *) &remote);
+      return ng.match(remote);
 
     return false;
   }
@@ -115,6 +114,7 @@ public:
   {
     ::setNonBlocking(d_socket);
   }
+
   //! Set the socket to blocking
   void setBlocking()
   {
@@ -125,35 +125,22 @@ public:
   {
     try {
       ::setReuseAddr(d_socket);
-    } catch (PDNSException &e) {
+    } catch (const PDNSException &e) {
       throw NetworkError(e.reason);
     }
   }
 
   //! Bind the socket to a specified endpoint
-  void bind(const ComboAddress &local)
+  void bind(const ComboAddress &local, bool reuseaddr=true)
   {
     int tmp=1;
-    if(setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR,(char*)&tmp,sizeof tmp)<0)
+    if(reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp)<0)
       throw NetworkError(string("Setsockopt failed: ")+strerror(errno));
 
-    if(::bind(d_socket,(struct sockaddr *)&local, local.getSocklen())<0)
+    if(::bind(d_socket, reinterpret_cast<const struct sockaddr *>(&local), local.getSocklen())<0)
       throw NetworkError("While binding: "+string(strerror(errno)));
   }
 
-#if 0
-  //! Bind the socket to a specified endpoint
-  void bind(const ComboAddress &ep)
-  {
-    ComboAddress local;
-    memset(reinterpret_cast<char *>(&local),0,sizeof(local));
-    local.sin_family=d_family;
-    local.sin_addr.s_addr=ep.address.byte;
-    local.sin_port=htons(ep.port);
-    
-    bind(local);
-  }
-#endif
   //! Connect the socket to a specified endpoint
   void connect(const ComboAddress &ep, int timeout=0)
   {
@@ -167,20 +154,22 @@ public:
       \param ep Will be filled with the origin of the datagram */
   void recvFrom(string &dgram, ComboAddress &ep)
   {
-    socklen_t remlen=sizeof(ep);
+    socklen_t remlen = sizeof(ep);
     ssize_t bytes;
-    if((bytes=recvfrom(d_socket, d_buffer, d_buflen, 0, (sockaddr *)&ep , &remlen)) <0)
+    d_buffer.resize(s_buflen);
+    if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&ep) , &remlen)) <0)
       throw NetworkError("After recvfrom: "+string(strerror(errno)));
     
-    dgram.assign(d_buffer,bytes);
+    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
   }
 
   bool recvFromAsync(string &dgram, ComboAddress &ep)
   {
     struct sockaddr_in remote;
-    socklen_t remlen=sizeof(remote);
+    socklen_t remlen = sizeof(remote);
     ssize_t bytes;
-    if((bytes=recvfrom(d_socket, d_buffer, d_buflen, 0, (sockaddr *)&remote, &remlen))<0) {
+    d_buffer.resize(s_buflen);
+    if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&remote), &remlen))<0) {
       if(errno!=EAGAIN) {
         throw NetworkError("After async recvfrom: "+string(strerror(errno)));
       }
@@ -188,7 +177,7 @@ public:
         return false;
       }
     }
-    dgram.assign(d_buffer,bytes);
+    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
     return true;
   }
 
@@ -196,7 +185,7 @@ public:
   //! For datagram sockets, send a datagram to a destination
   void sendTo(const char* msg, size_t len, const ComboAddress &ep)
   {
-    if(sendto(d_socket, msg, len, 0, (sockaddr *)&ep, ep.getSocklen())<0)
+    if(sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr *>(&ep), ep.getSocklen())<0)
       throw NetworkError("After sendto: "+string(strerror(errno)));
   }
 
@@ -233,9 +222,9 @@ public:
         throw NetworkError("Writing to a socket: "+string(strerror(errno)));
       if(!res)
         throw NetworkError("EOF on socket");
-      toWrite-=(size_t)res;
-      ptr+=(size_t)res;
-    }while(toWrite);
+      toWrite -= static_cast<size_t>(res);
+      ptr += static_cast<size_t>(res);
+    } while(toWrite);
 
   }
 
@@ -275,7 +264,7 @@ public:
   void writenWithTimeout(const void *buffer, size_t n, int timeout)
   {
     size_t bytes=n;
-    const char *ptr = (char*)buffer;
+    const char *ptr = reinterpret_cast<const char*>(buffer);
     ssize_t ret;
     while(bytes) {
       ret=::write(d_socket, ptr, bytes);
@@ -295,8 +284,8 @@ public:
         throw NetworkError("Did not fulfill TCP write due to EOF");
       }
 
-      ptr += (size_t) ret;
-      bytes -= (size_t) ret;
+      ptr += static_cast<size_t>(ret);
+      bytes -= static_cast<size_t>(ret);
     }
   }
 
@@ -325,19 +314,20 @@ public:
   //! Reads a block of data from the socket to a string
   void read(string &data)
   {
-    ssize_t res=::recv(d_socket,d_buffer,d_buflen,0);
+    d_buffer.resize(s_buflen);
+    ssize_t res=::recv(d_socket, &d_buffer[0], s_buflen, 0);
     if(res<0) 
       throw NetworkError("Reading from a socket: "+string(strerror(errno)));
-    data.assign(d_buffer,res);
+    data.assign(d_buffer, 0, static_cast<size_t>(res));
   }
 
   //! Reads a block of data from the socket to a block of memory
   size_t read(char *buffer, size_t bytes)
   {
-    ssize_t res=::recv(d_socket,buffer,bytes,0);
+    ssize_t res=::recv(d_socket, buffer, bytes, 0);
     if(res<0) 
       throw NetworkError("Reading from a socket: "+string(strerror(errno)));
-    return (size_t) res;
+    return static_cast<size_t>(res);
   }
 
   ssize_t readWithTimeout(char* buffer, size_t n, int timeout)
@@ -366,9 +356,9 @@ public:
   }
   
 private:
-  char *d_buffer;
+  static const size_t s_buflen{4096};
+  std::string d_buffer;
   int d_socket;
-  size_t d_buflen;
 };
 
 
index a4b0436f01113d51658a646d09e7be949d7d6a46..c4f46a5b1f8fc70a62cb183f2d7fa7097c34c86d 100644 (file)
@@ -218,7 +218,7 @@ vector<pair<T, unsigned int> >StatRing<T,Comp>::get() const
 
 void StatBag::declareRing(const string &name, const string &help, unsigned int size)
 {
-  d_rings[name]=StatRing<string>(size);
+  d_rings[name]=StatRing<string, CIStringCompare>(size);
   d_rings[name].setHelp(help);
 }
 
@@ -228,21 +228,33 @@ void StatBag::declareComboRing(const string &name, const string &help, unsigned
   d_comborings[name].setHelp(help);
 }
 
+void StatBag::declareDNSNameQTypeRing(const string &name, const string &help, unsigned int size)
+{
+  d_dnsnameqtyperings[name] = StatRing<std::tuple<DNSName, QType> >(size);
+  d_dnsnameqtyperings[name].setHelp(help);
+}
+
 
 vector<pair<string, unsigned int> > StatBag::getRing(const string &name)
 {
-  if(d_rings.count(name))
+  if(d_rings.count(name)) {
     return d_rings[name].get();
-  else {
+  }
+  vector<pair<string, unsigned int> > ret;
+
+  if (d_comborings.count(name)) {
     typedef pair<SComboAddress, unsigned int> stor_t;
     vector<stor_t> raw =d_comborings[name].get();
-    vector<pair<string, unsigned int> > ret;
     for(const stor_t& stor :  raw) {
       ret.push_back(make_pair(stor.first.ca.toString(), stor.second));
     }
-    return ret;
+  } else if(d_dnsnameqtyperings.count(name)) {
+    auto raw = d_dnsnameqtyperings[name].get();
+    for (auto const &e : raw) {
+      ret.push_back(make_pair(std::get<0>(e.first).toLogString() + "/" + std::get<1>(e.first).getName(), e.second));
+    }
   }
-    
+  return ret;
 }
 
 template<typename T, typename Comp>
@@ -256,16 +268,20 @@ void StatBag::resetRing(const string &name)
 {
   if(d_rings.count(name))
     d_rings[name].reset();
-  else
+  if(d_comborings.count(name))
     d_comborings[name].reset();
+  if(d_dnsnameqtyperings.count(name))
+    d_dnsnameqtyperings[name].reset();
 }
 
 void StatBag::resizeRing(const string &name, unsigned int newsize)
 {
   if(d_rings.count(name))
     d_rings[name].resize(newsize);
-  else
+  if(d_comborings.count(name))
     d_comborings[name].resize(newsize);
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].resize(newsize);
 }
 
 
@@ -273,33 +289,42 @@ unsigned int StatBag::getRingSize(const string &name)
 {
   if(d_rings.count(name))
     return d_rings[name].getSize();
-  else
+  if(d_comborings.count(name))
     return d_comborings[name].getSize();
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].getSize();
+  return 0;
 }
 
 string StatBag::getRingTitle(const string &name)
 {
   if(d_rings.count(name))
     return d_rings[name].getHelp();
-  else 
+  if(d_comborings.count(name))
     return d_comborings[name].getHelp();
+  if(d_dnsnameqtyperings.count(name))
+    return d_dnsnameqtyperings[name].getHelp();
+  return "";
 }
 
 vector<string>StatBag::listRings()
 {
   vector<string> ret;
-  for(map<string,StatRing<string> >::const_iterator i=d_rings.begin();i!=d_rings.end();++i)
+  for(auto i=d_rings.begin();i!=d_rings.end();++i)
     ret.push_back(i->first);
-  for(map<string,StatRing<SComboAddress> >::const_iterator i=d_comborings.begin();i!=d_comborings.end();++i)
+  for(auto i=d_comborings.begin();i!=d_comborings.end();++i)
     ret.push_back(i->first);
+  for(const auto &i : d_dnsnameqtyperings)
+    ret.push_back(i.first);
 
   return ret;
 }
 
 bool StatBag::ringExists(const string &name)
 {
-  return d_rings.count(name) || d_comborings.count(name);
+  return d_rings.count(name) || d_comborings.count(name) || d_dnsnameqtyperings.count(name);
 }
 
-template class StatRing<std::string>;
+template class StatRing<std::string, CIStringCompare>;
 template class StatRing<SComboAddress>;
+template class StatRing<std::tuple<DNSName, QType> >;
index b92c6a4371030dddbcde49794dd4b4fd60131808..2c8ca2850dfd4d7ed7caa23cda8ac7fc8b3aa385 100644 (file)
@@ -64,8 +64,9 @@ class StatBag
 {
   map<string, AtomicCounter *> d_stats;
   map<string, string> d_keyDescrips;
-  map<string,StatRing<string> >d_rings;
+  map<string,StatRing<string, CIStringCompare> >d_rings;
   map<string,StatRing<SComboAddress> >d_comborings;
+  map<string,StatRing<std::tuple<DNSName, QType> > >d_dnsnameqtyperings;
   typedef boost::function<uint64_t(const std::string&)> func_t;
   typedef map<string, func_t> funcstats_t;
   funcstats_t d_funcstats;
@@ -79,6 +80,7 @@ public:
 
   void declareRing(const string &name, const string &title, unsigned int size=10000);
   void declareComboRing(const string &name, const string &help, unsigned int size=10000);
+  void declareDNSNameQTypeRing(const string &name, const string &help, unsigned int size=10000);
   vector<pair<string, unsigned int> >getRing(const string &name);
   string getRingTitle(const string &name);
   void ringAccount(const char* name, const string &item)
@@ -98,6 +100,14 @@ public:
       d_comborings[name].account(item);
     }
   }
+  void ringAccount(const char* name, const DNSName &dnsname, const QType &qtype)
+  {
+    if(d_doRings) {
+      if(!d_dnsnameqtyperings.count(name))
+       throw runtime_error("Attempting to account to non-existent dnsname+qtype ring '"+std::string(name)+"'");
+      d_dnsnameqtyperings[name].account(std::make_tuple(dnsname, qtype));
+    }
+  }
 
   void doRings()
   {
index 39c20c0fc986302c559376842c25e90d9b845a86..ac8d78cb3035febf0d8446518c267a55bde5a4d2 100644 (file)
@@ -60,6 +60,10 @@ public:
   Stat print(unsigned int depth=0, Stat newstat=Stat(), bool silent=false) const;
   typedef boost::function<void(const StatNode*, const Stat& selfstat, const Stat& childstat)> visitor_t;
   void visit(visitor_t visitor, Stat& newstat, unsigned int depth=0) const;
+  bool empty() const
+  {
+    return children.empty() && s.remotes.empty();
+  }
   typedef std::map<std::string,StatNode, CIStringCompare> children_t;
   children_t children;
 
index 3cf44f16a37127c762b864f709fe2e8894691d76..1b71d3c190b1ae2b27b3b308415aa292421e0dfe 100644 (file)
@@ -55,10 +55,12 @@ unsigned int SyncRes::s_maxqperq;
 unsigned int SyncRes::s_maxtotusec;
 unsigned int SyncRes::s_maxdepth;
 unsigned int SyncRes::s_minimumTTL;
+unsigned int SyncRes::s_minimumECSTTL;
 unsigned int SyncRes::s_packetcachettl;
 unsigned int SyncRes::s_packetcacheservfailttl;
 unsigned int SyncRes::s_serverdownmaxfails;
 unsigned int SyncRes::s_serverdownthrottletime;
+unsigned int SyncRes::s_ecscachelimitttl;
 std::atomic<uint64_t> SyncRes::s_authzonequeries;
 std::atomic<uint64_t> SyncRes::s_queries;
 std::atomic<uint64_t> SyncRes::s_outgoingtimeouts;
@@ -72,8 +74,14 @@ std::atomic<uint64_t> SyncRes::s_nodelegated;
 std::atomic<uint64_t> SyncRes::s_unreachables;
 std::atomic<uint64_t> SyncRes::s_ecsqueries;
 std::atomic<uint64_t> SyncRes::s_ecsresponses;
+std::map<uint8_t, std::atomic<uint64_t>> SyncRes::s_ecsResponsesBySubnetSize4;
+std::map<uint8_t, std::atomic<uint64_t>> SyncRes::s_ecsResponsesBySubnetSize6;
+
 uint8_t SyncRes::s_ecsipv4limit;
 uint8_t SyncRes::s_ecsipv6limit;
+uint8_t SyncRes::s_ecsipv4cachelimit;
+uint8_t SyncRes::s_ecsipv6cachelimit;
+
 bool SyncRes::s_doIPv6;
 bool SyncRes::s_nopacketcache;
 bool SyncRes::s_rootNXTrust;
@@ -2412,7 +2420,31 @@ RCode::rcodes_ SyncRes::updateCacheFromRecords(unsigned int depth, LWResult& lwr
        - NS, A and AAAA (used for infra queries)
     */
     if (i->first.type != QType::NSEC3 && (i->first.type == QType::DS || i->first.type == QType::NS || i->first.type == QType::A || i->first.type == QType::AAAA || isAA || wasForwardRecurse)) {
-      t_RC->replace(d_now.tv_sec, i->first.name, QType(i->first.type), i->second.records, i->second.signatures, authorityRecs, i->first.type == QType::DS ? true : isAA, i->first.place == DNSResourceRecord::ANSWER ? ednsmask : boost::none, recordState);
+
+      bool doCache = true;
+      if (i->first.place == DNSResourceRecord::ANSWER && ednsmask) {
+        // If ednsmask is relevant, we do not want to cache if the scope prefix length is large and TTL is small
+        if (SyncRes::s_ecscachelimitttl > 0) {
+          bool manyMaskBits = (ednsmask->isIpv4() && ednsmask->getBits() > SyncRes::s_ecsipv4cachelimit) ||
+            (ednsmask->isIpv6() && ednsmask->getBits() > SyncRes::s_ecsipv6cachelimit);
+
+          if (manyMaskBits) {
+            uint32_t minttl = UINT32_MAX;
+            for (const auto &it : i->second.records) {
+              if (it.d_ttl < minttl)
+                minttl = it.d_ttl;
+            }
+            bool ttlIsSmall = minttl < SyncRes::s_ecscachelimitttl + d_now.tv_sec;
+            if (ttlIsSmall) {
+              // Case: many bits and ttlIsSmall
+              doCache = false;
+            }
+          }
+        }
+      }
+      if (doCache) {
+        t_RC->replace(d_now.tv_sec, i->first.name, QType(i->first.type), i->second.records, i->second.signatures, authorityRecs, i->first.type == QType::DS ? true : isAA, i->first.place == DNSResourceRecord::ANSWER ? ednsmask : boost::none, recordState);
+      }
     }
 
     if(i->first.place == DNSResourceRecord::ANSWER && ednsmask)
@@ -2700,6 +2732,14 @@ bool SyncRes::doResolveAtThisIP(const std::string& prefix, const DNSName& qname,
     if(ednsmask) {
       s_ecsresponses++;
       LOG(prefix<<qname<<": Received EDNS Client Subnet Mask "<<ednsmask->toString()<<" on response"<<endl);
+      if (ednsmask->getBits() > 0) {
+        if (ednsmask->isIpv4()) {
+          ++SyncRes::s_ecsResponsesBySubnetSize4.at(ednsmask->getBits()-1);
+        }
+        else {
+          ++SyncRes::s_ecsResponsesBySubnetSize6.at(ednsmask->getBits()-1);
+        }
+      }
     }
   }
 
@@ -2806,6 +2846,16 @@ bool SyncRes::processAnswer(unsigned int depth, LWResult& lwr, const DNSName& qn
     }
   }
 
+  /* if the answer is ECS-specific, a minimum TTL is set for this kind of answers
+     and it's higher than the global minimum TTL */
+  if (ednsmask && s_minimumECSTTL > 0 && (s_minimumTTL == 0 || s_minimumECSTTL > s_minimumTTL)) {
+    for(auto& rec : lwr.d_records) {
+      if (rec.d_place == DNSResourceRecord::ANSWER) {
+        rec.d_ttl = max(rec.d_ttl, s_minimumECSTTL);
+      }
+    }
+  }
+
   bool needWildcardProof = false;
   unsigned int wildcardLabelsCount;
   *rcode = updateCacheFromRecords(depth, lwr, qname, qtype, auth, wasForwarded, ednsmask, state, needWildcardProof, wildcardLabelsCount, sendRDQuery);
index a19f21e54e6d9ead23d52b2238e582f1690e5f08..a28ca22c57a34c0a35d5cc362cc07c11ed217551 100644 (file)
@@ -558,6 +558,20 @@ public:
     s_ecsScopeZero.source = scopeZeroMask;
   }
 
+  static void clearECSStats()
+  {
+    s_ecsqueries.store(0);
+    s_ecsresponses.store(0);
+
+    for (size_t idx = 0; idx < 32; idx++) {
+      SyncRes::s_ecsResponsesBySubnetSize4[idx].store(0);
+    }
+
+    for (size_t idx = 0; idx < 128; idx++) {
+      SyncRes::s_ecsResponsesBySubnetSize6[idx].store(0);
+    }
+  }
+
   explicit SyncRes(const struct timeval& now);
 
   int beginResolve(const DNSName &qname, const QType &qtype, uint16_t qclass, vector<DNSRecord>&ret);
@@ -686,9 +700,12 @@ public:
   static std::atomic<uint64_t> s_unreachables;
   static std::atomic<uint64_t> s_ecsqueries;
   static std::atomic<uint64_t> s_ecsresponses;
+  static std::map<uint8_t, std::atomic<uint64_t>> s_ecsResponsesBySubnetSize4;
+  static std::map<uint8_t, std::atomic<uint64_t>> s_ecsResponsesBySubnetSize6;
 
   static string s_serverID;
   static unsigned int s_minimumTTL;
+  static unsigned int s_minimumECSTTL;
   static unsigned int s_maxqperq;
   static unsigned int s_maxtotusec;
   static unsigned int s_maxdepth;
@@ -699,8 +716,11 @@ public:
   static unsigned int s_packetcacheservfailttl;
   static unsigned int s_serverdownmaxfails;
   static unsigned int s_serverdownthrottletime;
+  static unsigned int s_ecscachelimitttl;
   static uint8_t s_ecsipv4limit;
   static uint8_t s_ecsipv6limit;
+  static uint8_t s_ecsipv4cachelimit;
+  static uint8_t s_ecsipv6cachelimit;
   static bool s_doIPv6;
   static bool s_noEDNSPing;
   static bool s_noEDNS;
@@ -940,6 +960,7 @@ struct RecursorStats
   std::atomic<uint64_t> dnssecValidations; // should be the sum of all dnssecResult* stats
   std::map<vState, std::atomic<uint64_t> > dnssecResults;
   std::map<DNSFilterEngine::PolicyKind, std::atomic<uint64_t> > policyResults;
+  std::atomic<uint64_t> rebalancedQueries{0};
 };
 
 //! represents a running TCP/IP client session
@@ -1019,3 +1040,14 @@ void doCarbonDump(void*);
 void primeHints(void);
 
 extern __thread struct timeval g_now;
+
+struct ThreadTimes
+{
+  uint64_t msec;
+  vector<uint64_t> times;
+  ThreadTimes& operator+=(const ThreadTimes& rhs)
+  {
+    times.push_back(rhs.msec);
+    return *this;
+  }
+};
index 0d5bfa514eea42333c21f7fea3c86341bc356068..e14c535d859a24d3065ec6b937bf6f599f655383 100644 (file)
@@ -4,12 +4,18 @@
 
 #include "misc.hh"
 
+enum class IOState { Done, NeedRead, NeedWrite };
+
 class TLSConnection
 {
 public:
   virtual ~TLSConnection() { }
+  virtual void doHandshake() = 0;
+  virtual IOState tryHandshake() = 0;
   virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
+  virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
+  virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
   virtual void close() = 0;
 
 protected:
@@ -153,12 +159,14 @@ private:
 class TCPIOHandler
 {
 public:
+
   TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
   {
     if (ctx) {
       d_conn = ctx->getConnection(d_socket, timeout, now);
     }
   }
+
   ~TCPIOHandler()
   {
     if (d_conn) {
@@ -168,6 +176,15 @@ public:
       shutdown(d_socket, SHUT_RDWR);
     }
   }
+
+  IOState tryHandshake()
+  {
+    if (d_conn) {
+      return d_conn->tryHandshake();
+    }
+    return IOState::Done;
+  }
+
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
   {
     if (d_conn) {
@@ -176,6 +193,81 @@ public:
       return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
     }
   }
+
+  /* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+     Updates pos everytime a successful read occurs,
+     throws an std::runtime_error in case of IO error,
+     return Done when toRead bytes have been read, needRead or needWrite if the IO operation
+     would block.
+  */
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
+  {
+    if (buffer.size() < (pos + toRead)) {
+      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
+    }
+
+    if (d_conn) {
+      return d_conn->tryRead(buffer, pos, toRead);
+    }
+
+    size_t got = 0;
+    do {
+      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+      if (res == 0) {
+        throw runtime_error("EOF while reading message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedRead;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      got += static_cast<size_t>(res);
+    }
+    while (got < toRead);
+
+    return IOState::Done;
+  }
+
+  /* Tries to write exactly toWrite bytes from the buffer, starting at position pos.
+     Updates pos everytime a successful write occurs,
+     throws an std::runtime_error in case of IO error,
+     return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
+     would block.
+  */
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
+  {
+    if (d_conn) {
+      return d_conn->tryWrite(buffer, pos, toWrite);
+    }
+
+    size_t sent = 0;
+    do {
+      ssize_t res = ::write(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toWrite - sent);
+      if (res == 0) {
+        throw runtime_error("EOF while sending message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedWrite;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      sent += static_cast<size_t>(res);
+    }
+    while (sent < toWrite);
+
+    return IOState::Done;
+  }
+
   size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
   {
     if (d_conn) {
index 706305c87aa718be343b0308dee1ce78b9383ade..9c0a7b80c11256ca94fd0f56833d20e51f7b64b0 100644 (file)
@@ -1352,7 +1352,7 @@ void TCPNameserver::thread()
 
       int sock=-1;
       for(const pollfd& pfd :  d_prfds) {
-        if(pfd.revents == POLLIN) {
+        if(pfd.revents & POLLIN) {
           sock = pfd.fd;
           remote.sin4.sin_family = AF_INET6;
           addrlen=remote.getSocklen();
index 872592eadbf377c0ace3a900acb9cab799bc0eb3..f34968e7416f8a3f65739b823b08262c3bf1b79a 100644 (file)
@@ -1433,8 +1433,7 @@ BOOST_AUTO_TEST_CASE(test_isEDNSOptionInOpt) {
       return false;
     }
 
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
+    if (optLen < optRecordMinimumSize) {
       return false;
     }
 
index 5b1fbc0f346f7f8dc113bb7058c9342cfd03cdec..38d35dfdf3e0516979add623b7d54e29ca70423b 100644 (file)
@@ -40,7 +40,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       pwR.getHeader()->ra = 1;
       pwR.getHeader()->qr = 1;
       pwR.getHeader()->id = pwQ.getHeader()->id;
-      pwR.startRecord(a, QType::A, 100, QClass::IN, DNSResourceRecord::ANSWER);
+      pwR.startRecord(a, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
       pwR.xfr32BitInt(0x01020304);
       pwR.commit();
       uint16_t responseLen = response.size();
@@ -86,13 +86,13 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime);
       bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK);
       if (found == true) {
-        PC.expungeByName(a);
-        deleted++;
+        auto removed = PC.expungeByName(a);
+        BOOST_CHECK_EQUAL(removed, 1);
+        deleted += removed;
       }
     }
     BOOST_CHECK_EQUAL(PC.getSize(), counter - skipped - deleted);
 
-
     size_t matches=0;
     vector<DNSResourceRecord> entry;
     size_t expected=counter-skipped-deleted;
@@ -111,10 +111,15 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
         matches++;
       }
     }
-    BOOST_CHECK_EQUAL(matches, expected);
 
-    PC.expungeByName(DNSName(" hello"), QType::ANY, true);
+    /* in the unlikely event that the test took so long that the entries did expire.. */
+    auto expired = PC.purgeExpired();
+    BOOST_CHECK_EQUAL(matches + expired, expected);
+
+    auto remaining = PC.getSize();
+    auto removed = PC.expungeByName(DNSName(" hello"), QType::ANY, true);
     BOOST_CHECK_EQUAL(PC.getSize(), 0);
+    BOOST_CHECK_EQUAL(removed, remaining);
   }
   catch(PDNSException& e) {
     cerr<<"Had error: "<<e.reason<<endl;
diff --git a/pdns/test-ipcrypt_cc.cc b/pdns/test-ipcrypt_cc.cc
new file mode 100644 (file)
index 0000000..a3eae8e
--- /dev/null
@@ -0,0 +1,70 @@
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+#include <boost/test/unit_test.hpp>
+#include "ipcipher.hh"
+#include "misc.hh"
+
+using namespace boost;
+
+BOOST_AUTO_TEST_SUITE(test_ipcrypt_hh)
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt4)
+{
+  ComboAddress ca("127.0.0.1");
+  std::string key="0123456789ABCDEF";
+  auto encrypted = encryptCA(ca, key);
+
+  auto decrypted = decryptCA(encrypted, key);
+  BOOST_CHECK_EQUAL(ca.toString(), decrypted.toString());
+}
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt4_vector)
+{
+  vector<pair<string,string>>  tests{   // test vector from https://github.com/veorq/ipcrypt
+    {{"127.0.0.1"},{"114.62.227.59"}},
+    {{"8.8.8.8"},  {"46.48.51.50"}},
+    {{"1.2.3.4"},  {"171.238.15.199"}}};
+
+  std::string key="some 16-byte key";
+
+  for(const auto& p : tests) {
+    auto encrypted = encryptCA(ComboAddress(p.first), key);
+    BOOST_CHECK_EQUAL(encrypted.toString(), p.second);
+    auto decrypted = decryptCA(encrypted, key);
+    BOOST_CHECK_EQUAL(decrypted.toString(), p.first);
+  }
+
+  // test from Frank Denis' test.cc
+  ComboAddress ip("192.168.69.42"), out, dec;
+  string key2;
+  for(int n=0; n<16; ++n)
+    key2.append(1, (char)n+1);
+
+  for (unsigned int i = 0; i < 100000000UL; i++) {
+    out=encryptCA(ip, key2);
+    //    dec=decryptCA(out, key2);
+    // BOOST_CHECK(ip==dec);
+    ip=out;
+  }
+
+  ComboAddress expected("93.155.197.186");
+
+  BOOST_CHECK_EQUAL(ip.toString(), expected.toString());
+}
+
+
+BOOST_AUTO_TEST_CASE(test_ipcrypt6)
+{
+  ComboAddress ca("::1");
+  std::string key="0123456789ABCDEF";
+  auto encrypted = encryptCA(ca, key);
+
+  auto decrypted = decryptCA(encrypted, key);
+  BOOST_CHECK_EQUAL(ca.toString(), decrypted.toString());
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/pdns/test-mplexer.cc b/pdns/test-mplexer.cc
new file mode 100644 (file)
index 0000000..8a7412f
--- /dev/null
@@ -0,0 +1,182 @@
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <thread>
+#include <boost/test/unit_test.hpp>
+
+#include "mplexer.hh"
+#include "misc.hh"
+
+BOOST_AUTO_TEST_SUITE(mplexer)
+
+BOOST_AUTO_TEST_CASE(test_MPlexer) {
+  auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+  BOOST_REQUIRE(mplexer != nullptr);
+
+  struct timeval now;
+  int ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+
+  std::vector<int> readyFDs;
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_CHECK_EQUAL(readyFDs.size(), 0);
+
+  auto timeouts = mplexer->getTimeouts(now);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+
+  int pipes[2];
+  int res = pipe(pipes);
+  BOOST_REQUIRE_EQUAL(res, 0);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(pipes[0]), true);
+  BOOST_REQUIRE_EQUAL(setNonBlocking(pipes[1]), true);
+
+  /* let's declare a TTD that expired 5s ago */
+  struct timeval ttd = now;
+  ttd.tv_sec -= 5;
+
+  bool writeCBCalled = false;
+  auto writeCB = [](int fd, FDMultiplexer::funcparam_t param) {
+                        auto calledPtr = boost::any_cast<bool*>(param);
+                        BOOST_REQUIRE(calledPtr != nullptr);
+                        *calledPtr = true;
+                 };
+  mplexer->addWriteFD(pipes[1],
+                      writeCB,
+                      &writeCBCalled,
+                      &ttd);
+  /* we can't add it twice */
+  BOOST_CHECK_THROW(mplexer->addWriteFD(pipes[1],
+                                        writeCB,
+                                        &writeCBCalled,
+                                        &ttd),
+                    FDMultiplexerException);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), pipes[1]);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  /* no read timeouts */
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+  /* but we should have a write one */
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  /* can't remove from the wrong type of FD */
+  BOOST_CHECK_THROW(mplexer->removeReadFD(pipes[1]), FDMultiplexerException);
+  mplexer->removeWriteFD(pipes[1]);
+  /* can't remove a non-existing FD */
+  BOOST_CHECK_THROW(mplexer->removeWriteFD(pipes[0]), FDMultiplexerException);
+  BOOST_CHECK_THROW(mplexer->removeWriteFD(pipes[1]), FDMultiplexerException);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 0);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+
+  bool readCBCalled = false;
+  auto readCB = [](int fd, FDMultiplexer::funcparam_t param) {
+                        auto calledPtr = boost::any_cast<bool*>(param);
+                        BOOST_REQUIRE(calledPtr != nullptr);
+                        *calledPtr = true;
+                };
+  mplexer->addReadFD(pipes[0],
+                      readCB,
+                      &readCBCalled,
+                      &ttd);
+
+  /* not ready for reading yet */
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 0);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 0);
+  BOOST_CHECK_EQUAL(readCBCalled, false);
+
+  /* let's make the pipe readable */
+  BOOST_REQUIRE_EQUAL(write(pipes[1], "0", 1), 1);
+
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 1);
+  BOOST_CHECK_EQUAL(readyFDs.at(0), pipes[0]);
+
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 1);
+  BOOST_CHECK_EQUAL(readCBCalled, true);
+
+  /* add back the write FD */
+  mplexer->addWriteFD(pipes[1],
+                      writeCB,
+                      &writeCBCalled,
+                      &ttd);
+
+  /* both should be available */
+  readyFDs.clear();
+  mplexer->getAvailableFDs(readyFDs, 0);
+  BOOST_REQUIRE_EQUAL(readyFDs.size(), 2);
+
+  readCBCalled = false;
+  writeCBCalled = false;
+  ready = mplexer->run(&now, 100);
+  BOOST_CHECK_EQUAL(ready, 2);
+  BOOST_CHECK_EQUAL(readCBCalled, true);
+  BOOST_CHECK_EQUAL(writeCBCalled, true);
+
+  /* both the read and write FD should be reported */
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[0]);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  struct timeval past = ttd;
+  /* so five seconds before the actual TTD */
+  past.tv_sec -= 5;
+
+  /* no read timeouts */
+  timeouts = mplexer->getTimeouts(past, false);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+  /* and we should not have a write one either */
+  timeouts = mplexer->getTimeouts(past, true);
+  BOOST_CHECK_EQUAL(timeouts.size(), 0);
+
+  /* update the timeouts to now, they should not be reported anymore */
+  mplexer->setReadTTD(pipes[0], now, 0);
+  mplexer->setWriteTTD(pipes[1], now, 0);
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 0);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 0);
+
+  /* put it back into the past */
+  mplexer->setReadTTD(pipes[0], now, -5);
+  mplexer->setWriteTTD(pipes[1], now, -5);
+  timeouts = mplexer->getTimeouts(now, false);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[0]);
+  timeouts = mplexer->getTimeouts(now, true);
+  BOOST_REQUIRE_EQUAL(timeouts.size(), 1);
+  BOOST_CHECK_EQUAL(timeouts.at(0).first, pipes[1]);
+
+  mplexer->removeReadFD(pipes[0]);
+  mplexer->removeWriteFD(pipes[1]);
+
+  /* clean up */
+  close(pipes[0]);
+  close(pipes[1]);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
index e2e1066e19a8d2cdf51fea962fc8387cab3645d6..bd16702662aed97e411faa048d24cbc721f8963b 100644 (file)
@@ -156,9 +156,11 @@ try
     DNSPacket r(false);
     r.parse((char*)&pak[0], pak.size());
 
-    /* this step is necessary to get a valid hash */
-    DNSPacket cached(false);
-    g_PC->get(&q, &cached);
+    /* this step is necessary to get a valid hash
+       we directly compute the hash instead of querying the
+       cache because 1/ it's faster 2/ no deferred-lookup issues
+    */
+    q.setHash(g_PC->canHashPacket(q.getString()));
 
     const unsigned int maxTTL = 3600;
     g_PC->insert(&q, &r, maxTTL);
@@ -212,6 +214,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
       pthread_join(tid[i], &res);
 
     BOOST_CHECK_EQUAL(PC.size() + S.read("deferred-packetcache-inserts"), 400000);
+    BOOST_CHECK_EQUAL(S.read("deferred-packetcache-lookup"), 0);
     BOOST_CHECK_SMALL(1.0*S.read("deferred-packetcache-inserts"), 10000.0);
 
     for(int i=0; i < 4; ++i)
@@ -224,9 +227,12 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheThreaded) {
     cerr<<"Hits: "<<S.read("packetcache-hit")<<endl;
     cerr<<"Deferred inserts: "<<S.read("deferred-packetcache-inserts")<<endl;
     cerr<<"Deferred lookups: "<<S.read("deferred-packetcache-lookup")<<endl;
+    cerr<<g_PCmissing<<endl;
+    cerr<<PC.size()<<endl;
 */
+
     BOOST_CHECK_EQUAL(g_PCmissing + S.read("packetcache-hit"), 400000);
-    BOOST_CHECK_GT(S.read("deferred-packetcache-inserts") + S.read("deferred-packetcache-lookup"), g_PCmissing);
+    BOOST_CHECK_EQUAL(S.read("deferred-packetcache-inserts") + S.read("deferred-packetcache-lookup"), g_PCmissing);
   }
   catch(PDNSException& e) {
     cerr<<"Had error: "<<e.reason<<endl;
index 5e856e1be4248a84ce5280789df96285e44aee99..814dc640962f261b11a907f3c48124e9c9753953 100644 (file)
@@ -5,6 +5,7 @@
 #include "config.h"
 #endif
 #include <boost/test/unit_test.hpp>
+#include "arguments.hh"
 #include "dnswriter.hh"
 #include "dnsrecords.hh"
 #include "dns_random.hh"
@@ -24,12 +25,15 @@ BOOST_AUTO_TEST_CASE(test_recPacketCacheSimple) {
   uint32_t ttd=3600;
   BOOST_CHECK_EQUAL(rpc.size(), 0);
 
+  ::arg().set("rng")="auto";
+  ::arg().set("entropy-source")="/dev/urandom";
+
   DNSName qname("www.powerdns.com");
   vector<uint8_t> packet;
   DNSPacketWriter pw(packet, qname, QType::A);
   pw.getHeader()->rd=true;
   pw.getHeader()->qr=false;
-  pw.getHeader()->id=random();
+  pw.getHeader()->id=dns_random(UINT16_MAX);
   string qpacket((const char*)&packet[0], packet.size());
   pw.startRecord(qname, QType::A, ttd);
 
@@ -68,7 +72,7 @@ BOOST_AUTO_TEST_CASE(test_recPacketCacheSimple) {
 
   pw2.getHeader()->rd=true;
   pw2.getHeader()->qr=false;
-  pw2.getHeader()->id=random();
+  pw2.getHeader()->id=dns_random(UINT16_MAX);
   qpacket.assign((const char*)&packet[0], packet.size());
 
   found = rpc.getResponsePacket(tag, qpacket, time(nullptr), &fpacket, &age, &qhash);
@@ -96,12 +100,15 @@ BOOST_AUTO_TEST_CASE(test_recPacketCache_Tags) {
   uint32_t ttd=3600;
   BOOST_CHECK_EQUAL(rpc.size(), 0);
 
+  ::arg().set("rng")="auto";
+  ::arg().set("entropy-source")="/dev/urandom";
+
   DNSName qname("www.powerdns.com");
   vector<uint8_t> packet;
   DNSPacketWriter pw(packet, qname, QType::A);
   pw.getHeader()->rd=true;
   pw.getHeader()->qr=false;
-  pw.getHeader()->id=random();
+  pw.getHeader()->id=dns_random(UINT16_MAX);
   string qpacket(reinterpret_cast<const char*>(&packet[0]), packet.size());
   pw.startRecord(qname, QType::A, ttd);
 
index f06d90cb602df11353bd0350a2ebd3772600a68b..6e5d817a09d3e81016ffcc3b3d691acfae617408 100644 (file)
@@ -212,20 +212,14 @@ int Utility::makeUidNumeric(const string &username)
   return newuid;
 }
 
-
-// Returns a random number.
-long int Utility::random( void )
-{
-  return rand();
-}
-
 // Sets the random seed.
-void Utility::srandom( unsigned int seed )
+void Utility::srandom(void)
 {
-  ::srandom(seed);
+  struct timeval tv;
+  gettimeofday(&tv, 0);
+  ::srandom(tv.tv_sec ^ tv.tv_usec ^ getpid());
 }
 
-
 // Writes a vector.
 int Utility::writev(int socket, const iovec *vector, size_t count )
 {
index 2672a1826135fc973ff5b6f47882adae199a04e8..024fc089fb111a5b6f66f239e60d2f32ebb39cca 100644 (file)
@@ -128,11 +128,9 @@ public:
 
   //! Writes a vector.
   static int writev( Utility::sock_t socket, const iovec *vector, size_t count );
-  //! Returns a random number.
-  static long int random( void );
 
   //! Sets the random seed.
-  static void srandom( unsigned int seed );
+  static void srandom(void);
 
   //! Drops the program's group privileges.
   static void dropGroupPrivs( int uid, int gid );
index 979cc70851bf3be2f2c265649b048499536fd728..a168559e178505720390c4f6c9957c6f8bc8ce17 100644 (file)
 #include "dns.hh"
 #include "base64.hh"
 #include "json.hh"
+#include "uuid-utils.hh"
 #include <yahttp/router.hpp>
 
 json11::Json HttpRequest::json()
 {
   string err;
   if(this->body.empty()) {
-    g_log<<Logger::Debug<<"HTTP: JSON document expected in request body, but body was empty" << endl;
+    g_log<<Logger::Debug<<logprefix<<"JSON document expected in request body, but body was empty" << endl;
     throw HttpBadRequestException();
   }
   json11::Json doc = json11::Json::parse(this->body, err);
   if (doc.is_null()) {
-    g_log<<Logger::Debug<<"HTTP: parsing of JSON document failed:" << err << endl;
+    g_log<<Logger::Debug<<logprefix<<"parsing of JSON document failed:" << err << endl;
     throw HttpBadRequestException();
   }
   return doc;
@@ -130,13 +131,13 @@ static void apiWrapper(WebServer::HandlerFunction handler, HttpRequest* req, Htt
   resp->headers["access-control-allow-origin"] = "*";
 
   if (apikey.empty()) {
-    g_log<<Logger::Error<<"HTTP API Request \"" << req->url.path << "\": Authentication failed, API Key missing in config" << endl;
+    g_log<<Logger::Error<<req->logprefix<<"HTTP API Request \"" << req->url.path << "\": Authentication failed, API Key missing in config" << endl;
     throw HttpUnauthorizedException("X-API-Key");
   }
   bool auth_ok = req->compareHeader("x-api-key", apikey) || req->getvars["api-key"] == apikey;
   
   if (!auth_ok) {
-    g_log<<Logger::Error<<"HTTP Request \"" << req->url.path << "\": Authentication by API Key failed" << endl;
+    g_log<<Logger::Error<<req->logprefix<<"HTTP Request \"" << req->url.path << "\": Authentication by API Key failed" << endl;
     throw HttpUnauthorizedException("X-API-Key");
   }
 
@@ -178,7 +179,7 @@ static void webWrapper(WebServer::HandlerFunction handler, HttpRequest* req, Htt
   if (!password.empty()) {
     bool auth_ok = req->compareAuthorization(password);
     if (!auth_ok) {
-      g_log<<Logger::Debug<<"HTTP Request \"" << req->url.path << "\": Web Authentication failed" << endl;
+      g_log<<Logger::Debug<<req->logprefix<<"HTTP Request \"" << req->url.path << "\": Web Authentication failed" << endl;
       throw HttpUnauthorizedException("Basic");
     }
   }
@@ -204,11 +205,11 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
 
   try {
     if (!req.complete) {
-      g_log<<Logger::Debug<<"HTTP: Incomplete request" << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"Incomplete request" << endl;
       throw HttpBadRequestException();
     }
 
-    g_log<<Logger::Debug<<"HTTP: Handling request \"" << req.url.path << "\"" << endl;
+    g_log<<Logger::Debug<<req.logprefix<<"Handling request \"" << req.url.path << "\"" << endl;
 
     YaHTTP::strstr_map_t::iterator header;
 
@@ -223,33 +224,34 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
 
     YaHTTP::THandlerFunction handler;
     if (!YaHTTP::Router::Route(&req, handler)) {
-      g_log<<Logger::Debug<<"HTTP: No route found for \"" << req.url.path << "\"" << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"No route found for \"" << req.url.path << "\"" << endl;
       throw HttpNotFoundException();
     }
 
     try {
       handler(&req, &resp);
-      g_log<<Logger::Debug<<"HTTP: Result for \"" << req.url.path << "\": " << resp.status << ", body length: " << resp.body.size() << endl;
+      g_log<<Logger::Debug<<req.logprefix<<"Result for \"" << req.url.path << "\": " << resp.status << ", body length: " << resp.body.size() << endl;
     }
     catch(HttpException&) {
       throw;
     }
     catch(PDNSException &e) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": Exception: " << e.reason << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": Exception: " << e.reason << endl;
       throw HttpInternalServerErrorException();
     }
     catch(std::exception &e) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": STL Exception: " << e.what() << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": STL Exception: " << e.what() << endl;
       throw HttpInternalServerErrorException();
     }
     catch(...) {
-      g_log<<Logger::Error<<"HTTP ISE for \""<< req.url.path << "\": Unknown Exception" << endl;
+      g_log<<Logger::Error<<req.logprefix<<"HTTP ISE for \""<< req.url.path << "\": Unknown Exception" << endl;
       throw HttpInternalServerErrorException();
     }
   }
   catch(HttpException &e) {
     resp = e.response();
-    g_log<<Logger::Debug<<"HTTP: Error result for \"" << req.url.path << "\": " << resp.status << endl;
+    // TODO rm this logline?
+    g_log<<Logger::Debug<<req.logprefix<<"Error result for \"" << req.url.path << "\": " << resp.status << endl;
     string what = YaHTTP::Utility::status2text(resp.status);
     if(req.accept_html) {
       resp.headers["Content-Type"] = "text/html; charset=utf-8";
@@ -276,49 +278,129 @@ void WebServer::handleRequest(HttpRequest& req, HttpResponse& resp) const
   }
 }
 
-void WebServer::serveConnection(std::shared_ptr<Socket> client) const
-try {
-  HttpRequest req;
-  YaHTTP::AsyncRequestLoader yarl;
-  yarl.initialize(&req);
-  int timeout = 5;
-  client->setNonBlocking();
+void WebServer::logRequest(const HttpRequest& req, const ComboAddress& remote) const {
+  if (d_loglevel >= WebServer::LogLevel::Detailed) {
+    auto logprefix = req.logprefix;
+    g_log<<Logger::Notice<<logprefix<<"Request details:"<<endl;
 
-  try {
-    while(!req.complete) {
-      int bytes;
-      char buf[1024];
-      bytes = client->readWithTimeout(buf, sizeof(buf), timeout);
-      if (bytes > 0) {
-        string data = string(buf, bytes);
-        req.complete = yarl.feed(data);
-      } else {
-        // read error OR EOF
-        break;
+    bool first = true;
+    for (const auto& r : req.getvars) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" GET params:"<<endl;
       }
+      g_log<<Logger::Notice<<logprefix<<"  "<<r.first<<": "<<r.second<<endl;
     }
-    yarl.finalize();
-  } catch (YaHTTP::ParseError &e) {
-    // request stays incomplete
-  }
 
-  HttpResponse resp;
-  WebServer::handleRequest(req, resp);
-  ostringstream ss;
-  resp.write(ss);
-  string reply = ss.str();
+    first = true;
+    for (const auto& r : req.postvars) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" POST params:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<r.first<<": "<<r.second<<endl;
+    }
 
-  client->writenWithTimeout(reply.c_str(), reply.size(), timeout);
-}
-catch(PDNSException &e) {
-  g_log<<Logger::Error<<"HTTP Exception: "<<e.reason<<endl;
+    first = true;
+    for (const auto& h : req.headers) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" Headers:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<h.first<<": "<<h.second<<endl;
+    }
+
+    if (req.body.empty()) {
+      g_log<<Logger::Notice<<logprefix<<" No body"<<endl;
+    } else {
+      g_log<<Logger::Notice<<logprefix<<" Full body: "<<endl;
+      g_log<<Logger::Notice<<logprefix<<"  "<<req.body<<endl;
+    }
+  }
 }
-catch(std::exception &e) {
-  if(strstr(e.what(), "timeout")==0)
-    g_log<<Logger::Error<<"HTTP STL Exception: "<<e.what()<<endl;
+
+void WebServer::logResponse(const HttpResponse& resp, const ComboAddress& remote, const string& logprefix) const {
+  if (d_loglevel >= WebServer::LogLevel::Detailed) {
+    g_log<<Logger::Notice<<logprefix<<"Response details:"<<endl;
+    bool first = true;
+    for (const auto& h : resp.headers) {
+      if (first) {
+        first = false;
+        g_log<<Logger::Notice<<logprefix<<" Headers:"<<endl;
+      }
+      g_log<<Logger::Notice<<logprefix<<"  "<<h.first<<": "<<h.second<<endl;
+    }
+    if (resp.body.empty()) {
+      g_log<<Logger::Notice<<logprefix<<" No body"<<endl;
+    } else {
+      g_log<<Logger::Notice<<logprefix<<" Full body: "<<endl;
+      g_log<<Logger::Notice<<logprefix<<"  "<<resp.body<<endl;
+    }
+  }
 }
-catch(...) {
-  g_log<<Logger::Error<<"HTTP: Unknown exception"<<endl;
+
+void WebServer::serveConnection(std::shared_ptr<Socket> client) const {
+  const string logprefix = d_logprefix + to_string(getUniqueID()) + " ";
+
+  HttpRequest req(logprefix);
+  HttpResponse resp;
+  ComboAddress remote;
+  string reply;
+
+  try {
+    YaHTTP::AsyncRequestLoader yarl;
+    yarl.initialize(&req);
+    int timeout = 5;
+    client->setNonBlocking();
+
+    try {
+      while(!req.complete) {
+        int bytes;
+        char buf[1024];
+        bytes = client->readWithTimeout(buf, sizeof(buf), timeout);
+        if (bytes > 0) {
+          string data = string(buf, bytes);
+          req.complete = yarl.feed(data);
+        } else {
+          // read error OR EOF
+          break;
+        }
+      }
+      yarl.finalize();
+    } catch (YaHTTP::ParseError &e) {
+      // request stays incomplete
+      g_log<<Logger::Warning<<logprefix<<"Unable to parse request: "<<e.what()<<endl;
+    }
+
+    if (d_loglevel >= WebServer::LogLevel::None) {
+      client->getRemote(remote);
+    }
+
+    logRequest(req, remote);
+
+    WebServer::handleRequest(req, resp);
+    ostringstream ss;
+    resp.write(ss);
+    reply = ss.str();
+
+    logResponse(resp, remote, logprefix);
+
+    client->writenWithTimeout(reply.c_str(), reply.size(), timeout);
+  }
+  catch(PDNSException &e) {
+    g_log<<Logger::Error<<logprefix<<"HTTP Exception: "<<e.reason<<endl;
+  }
+  catch(std::exception &e) {
+    if(strstr(e.what(), "timeout")==0)
+      g_log<<Logger::Error<<logprefix<<"HTTP STL Exception: "<<e.what()<<endl;
+  }
+  catch(...) {
+    g_log<<Logger::Error<<logprefix<<"Unknown exception"<<endl;
+  }
+
+  if (d_loglevel >= WebServer::LogLevel::Normal) {
+    g_log<<Logger::Notice<<logprefix<<remote<<" \""<<req.method<<" "<<req.url.path<<" HTTP/"<<req.versionStr(req.version)<<"\" "<<resp.status<<" "<<reply.size()<<endl;
+  }
 }
 
 WebServer::WebServer(const string &listenaddress, int port) :
@@ -332,10 +414,10 @@ void WebServer::bind()
 {
   try {
     d_server = createServer();
-    g_log<<Logger::Warning<<"Listening for HTTP requests on "<<d_server->d_local.toStringWithPort()<<endl;
+    g_log<<Logger::Warning<<d_logprefix<<"Listening for HTTP requests on "<<d_server->d_local.toStringWithPort()<<endl;
   }
   catch(NetworkError &e) {
-    g_log<<Logger::Error<<"Listening on HTTP socket failed: "<<e.what()<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"Listening on HTTP socket failed: "<<e.what()<<endl;
     d_server = nullptr;
   }
 }
@@ -357,28 +439,28 @@ void WebServer::go()
         } else {
           ComboAddress remote;
           if (client->getRemote(remote))
-            g_log<<Logger::Error<<"Webserver closing socket: remote ("<< remote.toString() <<") does not match the set ACL("<<d_acl.toString()<<")"<<endl;
+            g_log<<Logger::Error<<d_logprefix<<"Webserver closing socket: remote ("<< remote.toString() <<") does not match the set ACL("<<d_acl.toString()<<")"<<endl;
         }
       }
       catch(PDNSException &e) {
-        g_log<<Logger::Error<<"PDNSException while accepting a connection in main webserver thread: "<<e.reason<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"PDNSException while accepting a connection in main webserver thread: "<<e.reason<<endl;
       }
       catch(std::exception &e) {
-        g_log<<Logger::Error<<"STL Exception while accepting a connection in main webserver thread: "<<e.what()<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"STL Exception while accepting a connection in main webserver thread: "<<e.what()<<endl;
       }
       catch(...) {
-        g_log<<Logger::Error<<"Unknown exception while accepting a connection in main webserver thread"<<endl;
+        g_log<<Logger::Error<<d_logprefix<<"Unknown exception while accepting a connection in main webserver thread"<<endl;
       }
     }
   }
   catch(PDNSException &e) {
-    g_log<<Logger::Error<<"PDNSException in main webserver thread: "<<e.reason<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"PDNSException in main webserver thread: "<<e.reason<<endl;
   }
   catch(std::exception &e) {
-    g_log<<Logger::Error<<"STL Exception in main webserver thread: "<<e.what()<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"STL Exception in main webserver thread: "<<e.what()<<endl;
   }
   catch(...) {
-    g_log<<Logger::Error<<"Unknown exception in main webserver thread"<<endl;
+    g_log<<Logger::Error<<d_logprefix<<"Unknown exception in main webserver thread"<<endl;
   }
   _exit(1);
 }
index 5fa351a667e7acfa44064d4662a4e84a4bc16302..500f157044fb358e0607b5b067ae3c6343a8d214 100644 (file)
@@ -34,11 +34,12 @@ class WebServer;
 
 class HttpRequest : public YaHTTP::Request {
 public:
-  HttpRequest() : YaHTTP::Request(), accept_json(false), accept_html(false), complete(false) { };
+  HttpRequest(const string& logprefix="") : YaHTTP::Request(), accept_json(false), accept_html(false), complete(false), logprefix(logprefix) { };
 
   bool accept_json;
   bool accept_html;
   bool complete;
+  string logprefix;
   json11::Json json();
 
   // checks password _only_.
@@ -184,8 +185,43 @@ public:
   void registerApiHandler(const string& url, HandlerFunction handler);
   void registerWebHandler(const string& url, HandlerFunction handler);
 
+  enum class LogLevel : uint8_t {
+    None = 0,                // No logs from requests at all
+    Normal = 10,             // A "common log format"-like line e.g. '127.0.0.1 "GET /apache_pb.gif HTTP/1.0" 200 2326'
+    Detailed = 20,           // The full request headers and body, and the full response headers and body
+  };
+
+  void setLogLevel(const string& level) {
+    if (level == "none") {
+      d_loglevel = LogLevel::None;
+      return;
+    }
+
+    if (level == "normal") {
+      d_loglevel = LogLevel::Normal;
+      return;
+    }
+
+    if (level == "detailed") {
+      d_loglevel = LogLevel::Detailed;
+      return;
+    }
+
+    throw PDNSException("Unknown webserver log level: " + level);
+  }
+
+  void setLogLevel(const LogLevel level) {
+    d_loglevel = level;
+  };
+
+  LogLevel getLogLevel() {
+    return d_loglevel;
+  };
+
 protected:
   void registerBareHandler(const string& url, HandlerFunction handler);
+  void logRequest(const HttpRequest& req, const ComboAddress& remote) const;
+  void logResponse(const HttpResponse& resp, const ComboAddress& remote, const string& logprefix) const;
 
   virtual std::shared_ptr<Server> createServer() {
     return std::make_shared<Server>(d_listenaddress, d_port);
@@ -203,6 +239,11 @@ protected:
   bool d_registerWebHandlerCalled{false};
 
   NetmaskGroup d_acl;
+
+  const string d_logprefix = "[webserver] ";
+
+  // Describes the amount of logging the webserver does
+  WebServer::LogLevel d_loglevel{WebServer::LogLevel::Detailed};
 };
 
 #endif /* WEBSERVER_HH */
index 6106d25f3329d76107fc6134f991ee1347be0e22..ae62d7dd7f02b9381cd394b1b4c9b707864edad6 100644 (file)
@@ -159,15 +159,29 @@ void apiServerStatistics(HttpRequest* req, HttpResponse* resp) {
   if(req->method != "GET")
     throw HttpMethodNotAllowedException();
 
+  Json::array doc;
+  string name = req->getvars["statistic"];
+  if (!name.empty()) {
+    auto stat = productServerStatisticsFetch(name);
+    if (!stat) {
+      throw ApiException("Unknown statistic name");
+    }
+
+    doc.push_back(Json::object {
+      { "type", "StatisticItem" },
+      { "name", name },
+      { "value", std::to_string(*stat) },
+    });
+
+    resp->setBody(doc);
+
+    return;
+  }
+
   typedef map<string, string> stat_items_t;
   stat_items_t general_stats;
   productServerStatisticsFetch(general_stats);
 
-  auto resp_qtype_stats = g_rs.getQTypeResponseCounts();
-  auto resp_size_stats = g_rs.getSizeResponseCounts();
-  auto resp_rcode_stats = g_rs.getRCodeResponseCounts();
-
-  Json::array doc;
   for(const auto& item : general_stats) {
     doc.push_back(Json::object {
       { "type", "StatisticItem" },
@@ -176,6 +190,9 @@ void apiServerStatistics(HttpRequest* req, HttpResponse* resp) {
     });
   }
 
+  auto resp_qtype_stats = g_rs.getQTypeResponseCounts();
+  auto resp_size_stats = g_rs.getSizeResponseCounts();
+  auto resp_rcode_stats = g_rs.getRCodeResponseCounts();
   {
     Json::array values;
     for(const auto& item : resp_qtype_stats) {
index 0de6eda580b93760b2aaece1179bac32cc20d508..e280a20abed8f8645a14a25a65296acd80ea9201 100644 (file)
@@ -41,5 +41,6 @@ DNSName apiNameToDNSName(const string& name);
 
 // To be provided by product code.
 void productServerStatisticsFetch(std::map<string,string>& out);
+boost::optional<uint64_t> productServerStatisticsFetch(const std::string& name);
 
 #endif /* PDNS_WSAPI_HH */
index eb6a9707b723a2c41a708a1c86055352ac22d411..9dfa4babaced99f9926c521d73d2b4197c906609 100644 (file)
@@ -72,6 +72,7 @@ AuthWebServer::AuthWebServer() :
     d_ws = new WebServer(arg()["webserver-address"], arg().asNum("webserver-port"));
     d_ws->setApiKey(arg()["api-key"]);
     d_ws->setPassword(arg()["webserver-password"]);
+    d_ws->setLogLevel(arg()["webserver-loglevel"]);
 
     NetmaskGroup acl;
     acl.toMasks(::arg()["webserver-allow-from"]);
@@ -282,9 +283,10 @@ void AuthWebServer::indexfunction(HttpRequest* req, HttpResponse* resp)
 
   ret<<"Total queries: "<<S.read("udp-queries")<<". Question/answer latency: "<<S.read("latency")/1000.0<<"ms</p><br>"<<endl;
   if(req->getvars["ring"].empty()) {
-    vector<string>entries=S.listRings();
-    for(vector<string>::const_iterator i=entries.begin();i!=entries.end();++i)
-      printtable(ret,*i,S.getRingTitle(*i));
+    auto entries = S.listRings();
+    for(const auto &i: entries) {
+      printtable(ret, i, S.getRingTitle(i));
+    }
 
     printvars(ret);
     if(arg().mustDo("webserver-print-arguments"))
@@ -503,6 +505,17 @@ void productServerStatisticsFetch(map<string,string>& out)
   out["uptime"] = std::to_string(time(0) - s_starttime);
 }
 
+boost::optional<uint64_t> productServerStatisticsFetch(const std::string& name)
+{
+  try {
+    // ::read() calls ::exists() which throws a PDNSException when the key does not exist
+    return S.read(name);
+  }
+  catch(...) {
+    return boost::none;
+  }
+}
+
 static void validateGatheredRRType(const DNSResourceRecord& rr) {
   if (rr.qtype.getCode() == QType::OPT || rr.qtype.getCode() == QType::TSIG) {
     throw ApiException("RRset "+rr.qname.toString()+" IN "+rr.qtype.getName()+": invalid type given");
@@ -1696,6 +1709,11 @@ static void apiServerZoneDetail(HttpRequest* req, HttpResponse* resp) {
     if(!di.backend->deleteDomain(zonename))
       throw ApiException("Deleting domain '"+zonename.toString()+"' failed: backend delete failed/unsupported");
 
+    // clear caches
+    DNSSECKeeper dk(&B);
+    dk.clearCaches(zonename);
+    purgeAuthCaches(zonename.toString() + "$");
+
     // empty body on success
     resp->body = "";
     resp->status = 204; // No Content: declare that the zone is gone now
index 6d262656b5e04857920a6269454aac733be6538a..bf4038a2a51a92fbd10aa9489134dd855368e34e 100644 (file)
@@ -41,6 +41,7 @@
 #include "ext/incbin/incbin.h"
 #include "rec-lua-conf.hh"
 #include "rpzloader.hh"
+#include "uuid-utils.hh"
 
 extern thread_local FDMultiplexer* t_fdm;
 
@@ -48,10 +49,15 @@ using json11::Json;
 
 void productServerStatisticsFetch(map<string,string>& out)
 {
-  map<string,string> stats = getAllStatsMap();
+  map<string,string> stats = getAllStatsMap(StatComponent::API);
   out.swap(stats);
 }
 
+boost::optional<uint64_t> productServerStatisticsFetch(const std::string& name)
+{
+  return getStatByName(name);
+}
+
 static void apiWriteConfigFile(const string& filebasename, const string& content)
 {
   if (::arg()["api-config-dir"].empty()) {
@@ -452,6 +458,7 @@ RecursorWebServer::RecursorWebServer(FDMultiplexer* fdm)
   d_ws = new AsyncWebServer(fdm, arg()["webserver-address"], arg().asNum("webserver-port"));
   d_ws->setApiKey(arg()["api-key"]);
   d_ws->setPassword(arg()["webserver-password"]);
+  d_ws->setLogLevel(arg()["webserver-loglevel"]);
 
   NetmaskGroup acl;
   acl.toMasks(::arg()["webserver-allow-from"]);
@@ -620,50 +627,69 @@ void AsyncServer::newConnection()
 }
 
 // This is an entry point from FDM, so it needs to catch everything.
-void AsyncWebServer::serveConnection(std::shared_ptr<Socket> client) const
-try {
-  HttpRequest req;
-  YaHTTP::AsyncRequestLoader yarl;
-  yarl.initialize(&req);
-  client->setNonBlocking();
-
-  string data;
+void AsyncWebServer::serveConnection(std::shared_ptr<Socket> client) const {
+  const string logprefix = d_logprefix + to_string(getUniqueID()) + " ";
+
+  HttpRequest req(logprefix);
+  HttpResponse resp;
+  ComboAddress remote;
+  string reply;
+
   try {
-    while(!req.complete) {
-      int bytes = arecvtcp(data, 16384, client.get(), true);
-      if (bytes > 0) {
-        req.complete = yarl.feed(data);
-      } else {
-        // read error OR EOF
-        break;
+    YaHTTP::AsyncRequestLoader yarl;
+    yarl.initialize(&req);
+    client->setNonBlocking();
+
+    string data;
+    try {
+      while(!req.complete) {
+        int bytes = arecvtcp(data, 16384, client.get(), true);
+        if (bytes > 0) {
+          req.complete = yarl.feed(data);
+        } else {
+          // read error OR EOF
+          break;
+        }
       }
+      yarl.finalize();
+    } catch (YaHTTP::ParseError &e) {
+      // request stays incomplete
+      g_log<<Logger::Warning<<logprefix<<"Unable to parse request: "<<e.what()<<endl;
+    }
+
+    if (d_loglevel >= WebServer::LogLevel::None) {
+      client->getRemote(remote);
+    }
+
+    logRequest(req, remote);
+
+    WebServer::handleRequest(req, resp);
+    ostringstream ss;
+    resp.write(ss);
+    reply = ss.str();
+
+    logResponse(resp, remote, logprefix);
+
+    // now send the reply
+    if (asendtcp(reply, client.get()) == -1 || reply.empty()) {
+      g_log<<Logger::Error<<logprefix<<"Failed sending reply to HTTP client"<<endl;
     }
-    yarl.finalize();
-  } catch (YaHTTP::ParseError &e) {
-    // request stays incomplete
+  }
+  catch(PDNSException &e) {
+    g_log<<Logger::Error<<logprefix<<"Exception: "<<e.reason<<endl;
+  }
+  catch(std::exception &e) {
+    if(strstr(e.what(), "timeout")==0)
+      g_log<<Logger::Error<<logprefix<<"STL Exception: "<<e.what()<<endl;
+  }
+  catch(...) {
+    g_log<<Logger::Error<<logprefix<<"Unknown exception"<<endl;
   }
 
-  HttpResponse resp;
-  handleRequest(req, resp);
-  ostringstream ss;
-  resp.write(ss);
-  data = ss.str();
-
-  // now send the reply
-  if (asendtcp(data, client.get()) == -1 || data.empty()) {
-    g_log<<Logger::Error<<"Failed sending reply to HTTP client"<<endl;
+  if (d_loglevel >= WebServer::LogLevel::Normal) {
+    g_log<<Logger::Notice<<logprefix<<remote<<" \""<<req.method<<" "<<req.url.path<<" HTTP/"<<req.versionStr(req.version)<<"\" "<<resp.status<<" "<<reply.size()<<endl;
   }
 }
-catch(PDNSException &e) {
-  g_log<<Logger::Error<<"HTTP Exception: "<<e.reason<<endl;
-}
-catch(std::exception &e) {
-  if(strstr(e.what(), "timeout")==0)
-    g_log<<Logger::Error<<"HTTP STL Exception: "<<e.what()<<endl;
-}
-catch(...) {
-  g_log<<Logger::Error<<"HTTP: Unknown exception"<<endl;
-}
 
 void AsyncWebServer::go() {
   if (!d_server)
index 80835f999ed644c77096861e1aed81cefb2e7af2..bc24d08c5b8f5a04295b87546ab64e19000bf9c7 100644 (file)
@@ -57,3 +57,14 @@ class Servers(ApiTestCase):
             self.assertIn('60', [e['name'] for e in respsize_stats])
             self.assertIn('example.com/A', [e['name'] for e in queries_stats])
             self.assertIn('No Error', [e['name'] for e in rcode_stats])
+
+    def test_read_one_statistic(self):
+        r = self.session.get(self.url("/api/v1/servers/localhost/statistics?statistic=uptime"))
+        self.assert_success_json(r)
+        data = r.json()
+        self.assertIn('uptime', [e['name'] for e in data])
+
+    def test_read_one_non_existent_statistic(self):
+        r = self.session.get(self.url("/api/v1/servers/localhost/statistics?statistic=uptimeAAAA"))
+        self.assertEquals(r.status_code, 422)
+        self.assertIn("Unknown statistic name", r.json()['error'])
index e1df6cfa7df44a5631334af0776505a88560a587..70527759b3e4fcf1207f3f288e3e19a5c86cae22 100644 (file)
@@ -214,6 +214,7 @@ class DNSDistTest(unittest.TestCase):
         ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
             sock.bind(("127.0.0.1", port))
@@ -304,6 +305,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTCPConnection(cls, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -313,6 +315,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -385,6 +388,7 @@ class DNSDistTest(unittest.TestCase):
             for response in responses:
                 cls._toResponderQueue.put(response, True, timeout)
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -475,6 +479,7 @@ class DNSDistTest(unittest.TestCase):
         ourNonce = libnacl.utils.rand_nonce()
         theirNonce = None
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
index 6ad4ff3d674185a3e7ff34e263fb3e61ce69561a..7fbe25b04b7470821c56380ecbf9e670e62850c7 100644 (file)
@@ -2,7 +2,7 @@ dnspython>=1.11,<1.16.0
 nose>=1.3.7
 libnacl>=1.4.3
 requests>=2.1.0
-protobuf>=2.5,<3.0; sys_platform != 'darwin'
-protobuf>=3.0; sys_platform == 'darwin'
+protobuf>=2.5,<3.0; sys_platform != 'darwin' and sys_platform != 'openbsd6'
+protobuf>=3.0; sys_platform == 'darwin' or sys_platform == 'openbsd6'
 pysnmp>=4.3.4
 future>=0.17.1
index 00578a3e779b0fefcad922c0b99a2fcbbceb3cfb..3d751cfe8c1710860c5b63096a4a87b82c5d725d 100644 (file)
@@ -114,7 +114,7 @@ class TestAPIBasics(DNSDistTest):
             self.assertTrue(server['state'] in ['up', 'down', 'UP', 'DOWN'])
 
         for frontend in content['frontends']:
-            for key in ['id', 'address', 'udp', 'tcp', 'queries']:
+            for key in ['id', 'address', 'udp', 'tcp', 'type', 'queries']:
                 self.assertIn(key, frontend)
 
             for key in ['id', 'queries']:
@@ -226,6 +226,7 @@ class TestAPIBasics(DNSDistTest):
             values[entry['name']] = entry['value']
 
         expected = ['responses', 'servfail-responses', 'queries', 'acl-drops',
+                    'frontend-noerror', 'frontend-nxdomain', 'frontend-servfail',
                     'rule-drop', 'rule-nxdomain', 'rule-refused', 'self-answered', 'downstream-timeouts',
                     'downstream-send-errors', 'trunc-failures', 'no-policy', 'latency0-1',
                     'latency1-10', 'latency10-50', 'latency50-100', 'latency100-1000',
@@ -255,6 +256,7 @@ class TestAPIBasics(DNSDistTest):
         content = r.json()
 
         expected = ['responses', 'servfail-responses', 'queries', 'acl-drops',
+                    'frontend-noerror', 'frontend-nxdomain', 'frontend-servfail',
                     'rule-drop', 'rule-nxdomain', 'rule-refused', 'self-answered', 'downstream-timeouts',
                     'downstream-send-errors', 'trunc-failures', 'no-policy', 'latency0-1',
                     'latency1-10', 'latency10-50', 'latency50-100', 'latency100-1000',
index 4bb4c543f7fe43f9b92fe9ac08189d1e227609ce..6bc5004614ad49caf2460b98518f2ddc614063b0 100644 (file)
@@ -165,111 +165,111 @@ class TestAXFR(DNSDistTest):
         self.assertEqual(query, receivedQuery)
         self.assertEqual(len(receivedResponses), len(responses))
 
-    def testFourNoFirstSOAAXFR(self):
-        """
-        AXFR: Four messages, no SOA in the first one
-        """
-        name = 'fournosoainfirst.axfr.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AXFR', 'IN')
-        responses = []
-        soa = dns.rrset.from_text(name,
-                                  60,
-                                  dns.rdataclass.IN,
-                                  dns.rdatatype.SOA,
-                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
-        response = dns.message.make_response(query)
-        response.answer.append(dns.rrset.from_text(name,
-                                                   60,
-                                                   dns.rdataclass.IN,
-                                                   dns.rdatatype.A,
-                                                   '192.0.2.1'))
-        responses.append(response)
+    def testFourNoFirstSOAAXFR(self):
+        """
+        AXFR: Four messages, no SOA in the first one
+        """
+        name = 'fournosoainfirst.axfr.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AXFR', 'IN')
+        responses = []
+        soa = dns.rrset.from_text(name,
+                                  60,
+                                  dns.rdataclass.IN,
+                                  dns.rdatatype.SOA,
+                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
+        response = dns.message.make_response(query)
+        response.answer.append(dns.rrset.from_text(name,
+                                                   60,
+                                                   dns.rdataclass.IN,
+                                                   dns.rdatatype.A,
+                                                   '192.0.2.1'))
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(soa)
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(soa)
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text('dummy.' + name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text('dummy.' + name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.TXT,
-                                    'dummy')
-        response.answer.append(rrset)
-        response.answer.append(soa)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.TXT,
+                                    'dummy')
+        response.answer.append(rrset)
+        response.answer.append(soa)
+        responses.append(response)
 
-        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
-        receivedQuery.id = query.id
-        self.assertEqual(query, receivedQuery)
-        self.assertEqual(len(receivedResponses), 1)
+        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(len(receivedResponses), 1)
 
-    def testFourLastSOAInSecondAXFR(self):
-        """
-        AXFR: Four messages, SOA in the first one and the second one
-        """
-        name = 'foursecondsoainsecond.axfr.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AXFR', 'IN')
-        responses = []
-        soa = dns.rrset.from_text(name,
-                                  60,
-                                  dns.rdataclass.IN,
-                                  dns.rdatatype.SOA,
-                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
+    def testFourLastSOAInSecondAXFR(self):
+        """
+        AXFR: Four messages, SOA in the first one and the second one
+        """
+        name = 'foursecondsoainsecond.axfr.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AXFR', 'IN')
+        responses = []
+        soa = dns.rrset.from_text(name,
+                                  60,
+                                  dns.rdataclass.IN,
+                                  dns.rdatatype.SOA,
+                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
 
-        response = dns.message.make_response(query)
-        response.answer.append(soa)
-        response.answer.append(dns.rrset.from_text(name,
-                                                   60,
-                                                   dns.rdataclass.IN,
-                                                   dns.rdatatype.A,
-                                                   '192.0.2.1'))
-        responses.append(response)
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        response.answer.append(dns.rrset.from_text(name,
+                                                   60,
+                                                   dns.rdataclass.IN,
+                                                   dns.rdatatype.A,
+                                                   '192.0.2.1'))
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        response.answer.append(soa)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text('dummy.' + name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.AAAA,
-                                    '2001:DB8::1')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text('dummy.' + name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        response = dns.message.make_response(query)
-        rrset = dns.rrset.from_text(name,
-                                    60,
-                                    dns.rdataclass.IN,
-                                    dns.rdatatype.TXT,
-                                    'dummy')
-        response.answer.append(rrset)
-        responses.append(response)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.TXT,
+                                    'dummy')
+        response.answer.append(rrset)
+        responses.append(response)
 
-        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
-        receivedQuery.id = query.id
-        self.assertEqual(query, receivedQuery)
-        self.assertEqual(len(receivedResponses), 2)
+        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(len(receivedResponses), 2)
index ea943bb00cef825da4181b2c639a560202807037..78105512e88963f88d4a6a052003837d173f549f 100644 (file)
@@ -34,19 +34,14 @@ class TestAdvancedAllow(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedAllowDropped(self):
         """
@@ -56,11 +51,10 @@ class TestAdvancedAllow(DNSDistTest):
         """
         name = 'notallowed.advanced.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
 
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
 
 class TestAdvancedFixupCase(DNSDistTest):
 
@@ -91,20 +85,14 @@ class TestAdvancedFixupCase(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestAdvancedRemoveRD(DNSDistTest):
 
@@ -133,19 +121,14 @@ class TestAdvancedRemoveRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepRD(self):
         """
@@ -165,20 +148,14 @@ class TestAdvancedRemoveRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedAddCD(DNSDistTest):
 
@@ -208,19 +185,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedSetCDViaAction(self):
         """
@@ -242,19 +214,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepNoCD(self):
         """
@@ -274,19 +241,14 @@ class TestAdvancedAddCD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedClearRD(DNSDistTest):
 
@@ -316,19 +278,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedClearRDViaAction(self):
         """
@@ -350,19 +307,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAdvancedKeepRD(self):
         """
@@ -382,19 +334,14 @@ class TestAdvancedClearRD(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 
 class TestAdvancedACL(DNSDistTest):
@@ -415,11 +362,10 @@ class TestAdvancedACL(DNSDistTest):
         name = 'tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestAdvancedDelay(DNSDistTest):
 
@@ -528,19 +474,14 @@ class TestAdvancedAndNot(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testAOverUDPReturnsNotImplemented(self):
         """
@@ -628,11 +569,10 @@ class TestAdvancedOr(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 
 class TestAdvancedLogAction(DNSDistTest):
@@ -690,24 +630,19 @@ class TestAdvancedDNSSEC(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
-        (_, receivedResponse) = self.sendUDPQuery(doquery, response)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(doquery, response)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(doquery, response)
+            self.assertEquals(receivedResponse, None)
 
 class TestAdvancedQClass(DNSDistTest):
 
@@ -723,10 +658,10 @@ class TestAdvancedQClass(DNSDistTest):
         name = 'qclasschaos.advanced.tests.powerdns.com.'
         query = dns.message.make_query(name, 'TXT', 'CHAOS')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testAdvancedQClassINAllow(self):
         """
@@ -743,19 +678,14 @@ class TestAdvancedQClass(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedOpcode(DNSDistTest):
 
@@ -772,10 +702,10 @@ class TestAdvancedOpcode(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         query.set_opcode(dns.opcode.NOTIFY)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testAdvancedOpcodeUpdateINAllow(self):
         """
@@ -793,19 +723,14 @@ class TestAdvancedOpcode(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedNonTerminalRule(DNSDistTest):
 
@@ -835,19 +760,14 @@ class TestAdvancedNonTerminalRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEquals(expectedQuery, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEquals(expectedQuery, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedStringOnlyServer(DNSDistTest):
 
@@ -869,19 +789,14 @@ class TestAdvancedStringOnlyServer(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedRestoreFlagsOnSelfResponse(DNSDistTest):
 
@@ -912,13 +827,11 @@ class TestAdvancedRestoreFlagsOnSelfResponse(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(response, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(response, receivedResponse)
 
 class TestAdvancedQPS(DNSDistTest):
 
@@ -989,11 +902,10 @@ class TestAdvancedQPSNone(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedNMGRule(DNSDistTest):
 
@@ -1016,11 +928,10 @@ class TestAdvancedNMGRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestDSTPortRule(DNSDistTest):
 
@@ -1043,11 +954,10 @@ class TestDSTPortRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedLabelsCountRule(DNSDistTest):
 
@@ -1071,19 +981,14 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # more than 6 labels, the query should be refused
         name = 'not.ok.labelscount.advanced.tests.powerdns.com.'
@@ -1091,11 +996,10 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
         # less than 5 labels, the query should be refused
         name = 'labelscountadvanced.tests.powerdns.com.'
@@ -1103,11 +1007,10 @@ class TestAdvancedLabelsCountRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedWireLengthRule(DNSDistTest):
 
@@ -1130,19 +1033,14 @@ class TestAdvancedWireLengthRule(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # too short, the query should be refused
         name = 'short.qnamewirelength.advanced.tests.powerdns.com.'
@@ -1150,11 +1048,10 @@ class TestAdvancedWireLengthRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
         # too long, the query should be refused
         name = 'toolongtobevalid.qnamewirelength.advanced.tests.powerdns.com.'
@@ -1162,11 +1059,10 @@ class TestAdvancedWireLengthRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedIncludeDir(DNSDistTest):
 
@@ -1190,19 +1086,14 @@ class TestAdvancedIncludeDir(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
         # this one should be refused
         name = 'notincludedir.advanced.tests.powerdns.com.'
@@ -1210,11 +1101,10 @@ class TestAdvancedIncludeDir(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestAdvancedLuaDO(DNSDistTest):
 
@@ -1247,30 +1137,22 @@ class TestAdvancedLuaDO(DNSDistTest):
         doResponse.set_rcode(dns.rcode.NXDOMAIN)
 
         # without DO
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # with DO
-        (_, receivedResponse) = self.sendUDPQuery(queryWithDO, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        doResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, doResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(queryWithDO, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        doResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, doResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(queryWithDO, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            doResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, doResponse)
 
 class TestAdvancedLuaRefused(DNSDistTest):
 
@@ -1298,15 +1180,12 @@ class TestAdvancedLuaRefused(DNSDistTest):
         refusedResponse = dns.message.make_response(query)
         refusedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            refusedResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, refusedResponse)
 
 class TestAdvancedLuaActionReturnSyntax(DNSDistTest):
 
@@ -1334,15 +1213,12 @@ class TestAdvancedLuaActionReturnSyntax(DNSDistTest):
         refusedResponse = dns.message.make_response(query)
         refusedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        refusedResponse.id = receivedResponse.id
-        self.assertEquals(receivedResponse, refusedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            refusedResponse.id = receivedResponse.id
+            self.assertEquals(receivedResponse, refusedResponse)
 
 class TestAdvancedLuaTruncated(DNSDistTest):
 
@@ -1488,11 +1364,10 @@ class TestAdvancedRD(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAdvancedNoRDAllowed(self):
         """
@@ -1503,15 +1378,12 @@ class TestAdvancedRD(DNSDistTest):
         query.flags &= ~dns.flags.RD
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalPort(DNSDistTest):
 
@@ -1541,11 +1413,10 @@ class TestAdvancedGetLocalPort(DNSDistTest):
                                     'port-was-{}.local-port.advanced.tests.powerdns.com.'.format(self._dnsDistPort))
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalPortOnAnyBind(DNSDistTest):
 
@@ -1576,11 +1447,10 @@ class TestAdvancedGetLocalPortOnAnyBind(DNSDistTest):
                                     'port-was-{}.local-port-any.advanced.tests.powerdns.com.'.format(self._dnsDistPort))
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
 
@@ -1614,11 +1484,10 @@ class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
                                     'address-was-127-0-0-1.local-address-any.advanced.tests.powerdns.com.')
         response.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedLuaTempFailureTTL(DNSDistTest):
 
@@ -1655,12 +1524,14 @@ class TestAdvancedLuaTempFailureTTL(DNSDistTest):
                                     '::1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedEDNSOptionRule(DNSDistTest):
 
@@ -1679,10 +1550,10 @@ class TestAdvancedEDNSOptionRule(DNSDistTest):
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None)
+            self.assertEquals(receivedResponse, None)
 
     def testReplied(self):
         """
@@ -1695,25 +1566,29 @@ class TestAdvancedEDNSOptionRule(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[], payload=512)
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
 
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # and with no EDNS at all
         query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
 
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
 class TestAdvancedAllowHeaderOnly(DNSDistTest):
 
@@ -1775,7 +1650,7 @@ class TestAdvancedAllowHeaderOnly(DNSDistTest):
             self.assertEquals(query, receivedQuery)
             self.assertEquals(receivedResponse, response)
 
-class TestAdvancedEDNSVersionnRule(DNSDistTest):
+class TestAdvancedEDNSVersionRule(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}
index ea4b7b08b9a94d81d4e828cd60fd83fd3c0915cc..4ecdb096e0278093a4847068c4a149c82b8ab373 100644 (file)
@@ -30,11 +30,10 @@ class TestBasics(DNSDistTest):
         """
         name = 'drop.test.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
     def testAWithECS(self):
         """
@@ -52,15 +51,12 @@ class TestBasics(DNSDistTest):
 
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testSimpleA(self):
         """
@@ -76,19 +72,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testAnyIsTruncated(self):
         """
@@ -214,11 +205,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testQNameReturnsSpoofed(self):
         """
@@ -238,12 +228,10 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.4')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainAndQTypeReturnsNotImplemented(self):
         """
@@ -259,11 +247,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NOTIMP)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testDomainWithoutQTypeIsNotAffected(self):
         """
@@ -284,19 +271,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testOtherDomainANDQTypeIsNotAffected(self):
         """
@@ -317,19 +299,14 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testWrongResponse(self):
         """
@@ -353,17 +330,13 @@ class TestBasics(DNSDistTest):
                                     'nothing to see here')
         unrelatedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, unrelatedResponse)
-        self.assertTrue(receivedQuery)
-        self.assertEquals(receivedResponse, None)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, unrelatedResponse)
+            self.assertTrue(receivedQuery)
+            self.assertEquals(receivedResponse, None)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
 
     def testHeaderOnlyRefused(self):
         """
@@ -375,17 +348,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.REFUSED)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
     def testHeaderOnlyNoErrorResponse(self):
         """
@@ -396,17 +365,13 @@ class TestBasics(DNSDistTest):
         response = dns.message.make_response(query)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testHeaderOnlyNXDResponse(self):
         """
@@ -418,17 +383,13 @@ class TestBasics(DNSDistTest):
         response.set_rcode(dns.rcode.NXDOMAIN)
         response.question = []
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testAddActionDNSName(self):
         """
@@ -439,8 +400,10 @@ class TestBasics(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testAddActionDNSNames(self):
         """
@@ -451,9 +414,8 @@ class TestBasics(DNSDistTest):
             expectedResponse = dns.message.make_response(query)
             expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-            self.assertEquals(receivedResponse, expectedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (_, receivedResponse) = sender(query, response=None, useQueue=False)
+                self.assertEquals(receivedResponse, expectedResponse)
 
-if __name__ == '__main__':
-    unittest.main()
-    exit(0)
index d50aa88694c2316745a5e69651e25974114d6404..aff1d2b6ee60069aed9b336522351f4125e26553 100644 (file)
@@ -155,19 +155,14 @@ class TestCaching(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
-
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                self.assertTrue(receivedQuery)
+                self.assertTrue(receivedResponse)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(receivedResponse, response)
 
         for key in self._responsesCounter:
             value = self._responsesCounter[key]
@@ -192,19 +187,14 @@ class TestCaching(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
-
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            self.assertTrue(receivedQuery)
-            self.assertTrue(receivedResponse)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, response)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                self.assertTrue(receivedQuery)
+                self.assertTrue(receivedResponse)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(receivedResponse, response)
 
         for key in self._responsesCounter:
             value = self._responsesCounter[key]
@@ -609,7 +599,7 @@ class TestCachingNoStale(DNSDistTest):
         Cache: Cache entry, set backend down, we should not get a stale entry
 
         """
-        ttl = 1
+        ttl = 2
         name = 'nostale.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -663,7 +653,7 @@ class TestCachingStale(DNSDistTest):
 
         """
         misses = 0
-        ttl = 1
+        ttl = 2
         name = 'stale.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -729,7 +719,7 @@ class TestCachingStaleExpunged(DNSDistTest):
         """
         misses = 0
         drops = 0
-        ttl = 1
+        ttl = 2
         name = 'stale-but-expunged.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -803,7 +793,7 @@ class TestCachingStaleExpungePrevented(DNSDistTest):
         Cache: Cache entry, set backend down, wait for the cache cleaning to run and remove the entry, still get a cache HIT because the stale entry was not removed
         """
         misses = 0
-        ttl = 1
+        ttl = 2
         name = 'stale-not-expunged.cache.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
@@ -1578,7 +1568,7 @@ class TestCachingNegativeTTL(DNSDistTest):
 
         time.sleep(self._negCacheTTL + 1)
 
-        # we should not have cached for longer than the negativel TTL
+        # we should not have cached for longer than the negative TTL
         # so it should be a miss
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         self.assertTrue(receivedQuery)
@@ -1730,38 +1720,30 @@ class TestCachingECSWithoutPoolECS(DNSDistTest):
         response.answer.append(rrset)
 
         # first query to fill the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # next queries should hit the cache
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # over TCP too
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
         # we mark the backend as down
         self.sendConsoleCommand("getServer(0):setDown()")
 
         # we should NOT get a cached entry since it has ECS and we haven't asked the pool
         # to add ECS when no backend is up
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        # same over TCP
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestCachingECSWithPoolECS(DNSDistTest):
 
@@ -1794,38 +1776,30 @@ class TestCachingECSWithPoolECS(DNSDistTest):
         response.answer.append(rrset)
 
         # first query to fill the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, response)
 
         # next queries should hit the cache
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # over TCP too
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
         # we mark the backend as down
         self.sendConsoleCommand("getServer(0):setDown()")
 
         # we should STILL get a cached entry since it has ECS and we have asked the pool
         # to add ECS when no backend is up
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
-
-        # same over TCP
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, response)
 
 class TestCachingCollisionNoECSParsing(DNSDistTest):
 
index 1971ddd9e1dd6e4015fa10d26bdd310c0fff1229..f075f4859e2dfc7f8c0524d69bf78a74e974a2db 100644 (file)
@@ -137,7 +137,7 @@ class TestCarbon(DNSDistTest):
         for line in data1.splitlines():
             if expectedStart in line:
                 parts = line.split(b' ')
-                if 'servers-up' in line:
+                if b'servers-up' in line:
                     self.assertEquals(len(parts), 3)
                     self.assertTrue(parts[1].isdigit())
                     self.assertEquals(int(parts[1]), 2)
@@ -160,7 +160,7 @@ class TestCarbon(DNSDistTest):
         for line in data2.splitlines():
             if expectedStart in line:
                 parts = line.split(b' ')
-                if 'servers-up' in line:
+                if b'servers-up' in line:
                     self.assertEquals(len(parts), 3)
                     self.assertTrue(parts[1].isdigit())
                     self.assertEquals(int(parts[1]), 2)
index 4a4259420decf6b2419cc3fe9b8c69e50b4dc60b..a5768b5fb835ddc96c2630296db1541497d4bedf 100644 (file)
@@ -4,7 +4,11 @@ import json
 import requests
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+try:
+  range = xrange
+except NameError:
+  pass
 
 class DynBlocksTest(DNSDistTest):
 
index f5838bf40ce1efc50a58ff6352dcd06f7c76b372..a655fd5ba7f227c2d5271d96764c98a9ee875c2f 100644 (file)
@@ -23,17 +23,17 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.COOKIE]:count() ~= 2 then
           return DNSAction.Spoof, "192.0.2.2"
         end
-        if options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.3"
         end
-        if options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:getValues()[2]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.4"
         end
       elseif string.match(qname, 'cookie') then
         if options[EDNSOptionCode.COOKIE] == nil then
           return DNSAction.Spoof, "192.0.2.1"
         end
-        if options[EDNSOptionCode.COOKIE]:count() ~= 1 or options[EDNSOptionCode.COOKIE]:getValues()[0]:len() ~= 16 then
+        if options[EDNSOptionCode.COOKIE]:count() ~= 1 or options[EDNSOptionCode.COOKIE]:getValues()[1]:len() ~= 16 then
           return DNSAction.Spoof, "192.0.2.2"
         end
       end
@@ -42,7 +42,7 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.ECS] == nil then
           return DNSAction.Spoof, "192.0.2.51"
         end
-        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 8 then
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[1]:len() ~= 8 then
           return DNSAction.Spoof, "192.0.2.52"
         end
       end
@@ -51,7 +51,7 @@ class EDNSOptionsBase(DNSDistTest):
         if options[EDNSOptionCode.ECS] == nil then
           return DNSAction.Spoof, "192.0.2.101"
         end
-        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[0]:len() ~= 20 then
+        if options[EDNSOptionCode.ECS]:count() ~= 1 or options[EDNSOptionCode.ECS]:getValues()[1]:len() ~= 20 then
           return DNSAction.Spoof, "192.0.2.102"
         end
       end
@@ -66,7 +66,7 @@ class TestEDNSOptions(EDNSOptionsBase):
     _config_template = """
     %s
 
-    addLuaAction(AllRule(), testEDNSOptions)
+    addAction(AllRule(), LuaAction(testEDNSOptions))
 
     newServer{address="127.0.0.1:%s"}
     """
@@ -86,26 +86,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '192.0.2.255')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testCookie(self):
         """
         EDNS Options: Cookie
         """
         name = 'cookie.ednsoptions.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -115,19 +110,14 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS4(self):
         """
@@ -144,19 +134,14 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS6(self):
         """
@@ -173,26 +158,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testECS6Cookie(self):
         """
         EDNS Options: Cookie + ECS6
         """
         name = 'cookie-ecs6.ednsoptions.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
         response = dns.message.make_response(query)
@@ -203,28 +183,23 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
     def testMultiCookiesECS6(self):
         """
         EDNS Options: Two Cookies + ECS6
         """
         name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
-        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
-        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -234,26 +209,21 @@ class TestEDNSOptions(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
 
 class TestEDNSOptionsAddingECS(EDNSOptionsBase):
 
     _config_template = """
     %s
 
-    addLuaAction(AllRule(), testEDNSOptions)
+    addAction(AllRule(), LuaAction(testEDNSOptions))
 
     newServer{address="127.0.0.1:%s", useClientSubnet=true}
     """
@@ -275,26 +245,21 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
     def testCookie(self):
         """
         EDNS Options: Cookie (adding ECS)
         """
         name = 'cookie.ednsoptions-ecs.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco])
         ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[eco,ecso], payload=512)
@@ -306,19 +271,14 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery, 1)
+            self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
     def testECS4(self):
         """
@@ -337,19 +297,14 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testECS6(self):
         """
@@ -368,26 +323,21 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testECS6Cookie(self):
         """
         EDNS Options: Cookie + ECS6 (adding ECS)
         """
         name = 'cookie-ecs6.ednsoptions-ecs.tests.powerdns.com.'
-        eco = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[ecso,eco])
         ecsoResponse = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128, scope=56)
@@ -400,28 +350,23 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery, 1)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery, 1)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testMultiCookiesECS6(self):
         """
         EDNS Options: Two Cookies + ECS6
         """
         name = 'multiplecookies-ecs6.ednsoptions.tests.powerdns.com.'
-        eco1 = cookiesoption.CookiesOption('deadbeef', 'deadbeef')
+        eco1 = cookiesoption.CookiesOption(b'deadbeef', b'deadbeef')
         ecso = clientsubnetoption.ClientSubnetOption('2001:DB8::1', 128)
-        eco2 = cookiesoption.CookiesOption('deadc0de', 'deadc0de')
+        eco2 = cookiesoption.CookiesOption(b'deadc0de', b'deadc0de')
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, options=[eco1, ecso, eco2])
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -431,16 +376,11 @@ class TestEDNSOptionsAddingECS(EDNSOptionsBase):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(receivedQuery, query)
-        self.assertEquals(receivedResponse, response)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
index 7704641319a7e08531bb775f5896d105b50dd05b..c1d7c56f28ef69d9112be53ad38b782764395319 100644 (file)
@@ -18,7 +18,7 @@ class TestEDNSSelfGenerated(DNSDistTest):
       return DNSAction.Nxdomain, ""
     end
 
-    addLuaAction("lua.edns-self.tests.powerdns.com.", luarule)
+    addAction("lua.edns-self.tests.powerdns.com.", LuaAction(luarule))
 
     addAction("spoof.edns-self.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
 
@@ -36,33 +36,30 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'no-edns.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
@@ -75,11 +72,10 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoDO(self):
         """
@@ -90,45 +86,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-no-do.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
@@ -141,15 +128,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertFalse(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
     def testWithEDNSWithDO(self):
         """
@@ -160,45 +144,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-do.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=True)
@@ -211,15 +186,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
     def testWithEDNSNoOptions(self):
         """
@@ -231,45 +203,36 @@ class TestEDNSSelfGenerated(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.tc.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.lua.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
         name = 'edns-options.spoof.edns-self.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, options=[ecso], payload=512, want_dnssec=True)
@@ -282,15 +245,12 @@ class TestEDNSSelfGenerated(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
-        self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
-        self.assertEquals(receivedResponse.payload, 1042)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
+            self.assertTrue(receivedResponse.ednsflags & dns.flags.DO)
+            self.assertEquals(receivedResponse.payload, 1042)
 
 
 class TestEDNSSelfGeneratedDisabled(DNSDistTest):
@@ -309,7 +269,7 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
       return DNSAction.Nxdomain, ""
     end
 
-    addLuaAction("lua.edns-self-disabled.tests.powerdns.com.", luarule)
+    addAction("lua.edns-self-disabled.tests.powerdns.com.", LuaAction(luarule))
 
     addAction("spoof.edns-self-disabled.tests.powerdns.com.", SpoofAction({'192.0.2.1', '192.0.2.2'}))
 
@@ -327,33 +287,30 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.tc.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.flags |= dns.flags.TC
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.lua.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.NXDOMAIN)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
 
         name = 'edns-no-do.spoof.edns-self-disabled.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096, want_dnssec=False)
@@ -366,8 +323,7 @@ class TestEDNSSelfGeneratedDisabled(DNSDistTest):
                                                            dns.rdatatype.A,
                                                            '192.0.2.1', '192.0.2.2'))
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
index 6adb863ed012c16a06ec8a2cfcdf4ad1945652f5..87acf47da254373bff73743a83ef0bba60036d46 100644 (file)
@@ -40,19 +40,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -77,19 +72,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSECS(self):
         """
@@ -113,19 +103,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
 
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithoutECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithoutECS(response, receivedResponse)
 
     def testWithoutEDNSResponseWithECS(self):
         """
@@ -154,19 +139,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithECS(self):
         """
@@ -195,19 +175,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECSResponseWithCookiesThenECS(self):
         """
@@ -238,19 +213,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         expectedResponse.answer.append(rrset)
         expectedResponse.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithECSThenCookies(self):
         """
@@ -281,19 +251,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         expectedResponse.answer.append(rrset)
         response.use_edns(edns=True, payload=4096, options=[ecoResponse])
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=1)
 
     def testWithEDNSNoECSResponseWithCookiesThenECSThenCookies(self):
         """
@@ -323,20 +288,14 @@ class TestEdnsClientSubnetNoOverride(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse, withCookies=2)
 
 class TestEdnsClientSubnetOverride(DNSDistTest):
     """
@@ -376,19 +335,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithEDNSNoECS(self):
         """
@@ -414,19 +368,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithoutECS(expectedResponse, receivedResponse)
 
     def testWithEDNSShorterInitialECS(self):
         """
@@ -454,19 +403,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSLongerInitialECS(self):
         """
@@ -494,19 +438,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithEDNSSameSizeInitialECS(self):
         """
@@ -534,19 +473,14 @@ class TestEdnsClientSubnetOverride(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
 class TestECSDisabledByRuleOrLua(DNSDistTest):
     """
@@ -586,19 +520,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
         response.answer.append(rrset)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSDisabledViaRule(self):
         """
@@ -614,19 +543,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryNoEDNS(query, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
     def testWithECSDisabledViaLua(self):
         """
@@ -642,19 +566,14 @@ class TestECSDisabledByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryNoEDNS(query, receivedQuery)
-        self.checkResponseNoEDNS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryNoEDNS(query, receivedQuery)
+            self.checkResponseNoEDNS(response, receivedResponse)
 
 class TestECSOverrideSetByRuleOrLua(DNSDistTest):
     """
@@ -692,19 +611,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.checkQueryEDNSWithECS(query, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.checkQueryEDNSWithECS(query, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithECSOverrideSetViaRule(self):
         """
@@ -724,19 +638,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
     def testWithECSOverrideSetViaLua(self):
         """
@@ -756,19 +665,14 @@ class TestECSOverrideSetByRuleOrLua(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseEDNSWithECS(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseEDNSWithECS(response, receivedResponse)
 
 class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
     """
@@ -809,19 +713,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSPrefixLengthOverriddenViaRule(self):
         """
@@ -841,19 +740,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSPrefixLengthOverriddenViaLua(self):
         """
@@ -873,19 +767,14 @@ class TestECSPrefixLengthSetByRuleOrLua(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
 class TestECSPrefixSetByRule(DNSDistTest):
     """
@@ -921,19 +810,14 @@ class TestECSPrefixSetByRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
 
     def testWithECSSetByRule(self):
         """
@@ -953,16 +837,11 @@ class TestECSPrefixSetByRule(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
-        self.checkResponseNoEDNS(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.checkQueryEDNSWithECS(expectedQuery, receivedQuery)
+            self.checkResponseNoEDNS(expectedResponse, receivedResponse)
index cdee353ee10d7c110326936c6b01e32558012d7f..d7bacbb6323805bc2d05c795bcd50033cd03087d 100644 (file)
@@ -9,119 +9,11 @@ from dnsdisttests import DNSDistTest, Queue
 import dns
 import dnsmessage_pb2
 
-
-class TestProtobuf(DNSDistTest):
+class DNSDistProtobufTest(DNSDistTest):
     _protobufServerPort = 4242
     _protobufQueue = Queue()
     _protobufServerID = 'dnsdist-server-1'
     _protobufCounter = 0
-    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
-    _config_template = """
-    luasmn = newSuffixMatchNode()
-    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
-
-    function alterProtobufResponse(dq, protobuf)
-      if luasmn:check(dq.qname) then
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then
-          requestor:truncate(24)
-        else
-          requestor:truncate(56)
-        end
-        protobuf:setRequestor(requestor)
-
-        local tableTags = {}
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-
-        protobuf:setTagArray(tableTags)
-
-        protobuf:setTag('TestLabel3,TestData3')
-
-        protobuf:setTag("Response,456")
-
-      else
-
-        local tableTags = {}                                   -- called by testProtobuf()
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-        protobuf:setTagArray(tableTags)
-
-        protobuf:setTag('TestLabel3,TestData3')
-
-        protobuf:setTag("Response,456")
-
-      end
-    end
-
-    function alterProtobufQuery(dq, protobuf)
-
-      if luasmn:check(dq.qname) then
-        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
-        if requestor:isIPv4() then
-          requestor:truncate(24)
-        else
-          requestor:truncate(56)
-        end
-        protobuf:setRequestor(requestor)
-
-        local tableTags = {}
-        tableTags = dq:getTagArray()                           -- get table from DNSQuery
-
-        local tablePB = {}
-          for k, v in pairs( tableTags) do
-          table.insert(tablePB, k .. "," .. v)
-        end
-
-        protobuf:setTagArray(tablePB)                          -- store table in protobuf
-        protobuf:setTag("Query,123")                           -- add another tag entry in protobuf
-
-        protobuf:setResponseCode(dnsdist.NXDOMAIN)             -- set protobuf response code to be NXDOMAIN
-
-        local strReqName = dq.qname:toString()                 -- get request dns name
-
-        protobuf:setProtobufResponseType()                     -- set protobuf to look like a response and not a query, with 0 default time
-
-        blobData = '\127' .. '\000' .. '\000' .. '\001'                -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
-
-        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
-
-        protobuf:setBytes(65)                                  -- set the size of the query to confirm in checkProtobufBase
-
-      else
-
-        local tableTags = {}                                    -- called by testProtobuf()
-        table.insert(tableTags, "TestLabel1,TestData1")
-        table.insert(tableTags, "TestLabel2,TestData2")
-
-        protobuf:setTagArray(tableTags)
-        protobuf:setTag('TestLabel3,TestData3')
-        protobuf:setTag("Query,123")
-
-      end
-    end
-
-    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
-      local tt = {}
-      tt["TestLabel1"] = "TestData1"
-      tt["TestLabel2"] = "TestData2"
-
-      dq:setTagArray(tt)
-
-      dq:setTag("TestLabel3","TestData3")
-      return DNSAction.None, ""                                -- continue to the next rule
-    end
-
-    newServer{address="127.0.0.1:%s", useClientSubnet=true}
-    rl = newRemoteLogger('127.0.0.1:%s')
-
-    addAction(AllRule(), LuaAction(alterLuaFirst))                                                     -- Add tags to DNSQuery first
-
-    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))                             -- Send protobuf message before lookup
-
-    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))    -- Send protobuf message after lookup
-
-    """
 
     @classmethod
     def ProtobufListener(cls, port):
@@ -188,7 +80,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(msg.id, query.id)
         self.assertTrue(msg.HasField('inBytes'))
         self.assertTrue(msg.HasField('serverIdentity'))
-        self.assertEquals(msg.serverIdentity, self._protobufServerID)
+        self.assertEquals(msg.serverIdentity, self._protobufServerID.encode('utf-8'))
 
         if normalQueryResponse:
           # compare inBytes with length of query/response
@@ -244,6 +136,115 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(record.ttl, rttl)
         self.assertTrue(record.HasField('rdata'))
 
+class TestProtobuf(DNSDistProtobufTest):
+    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
+    _config_template = """
+    luasmn = newSuffixMatchNode()
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
+
+    function alterProtobufResponse(dq, protobuf)
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then
+          requestor:truncate(24)
+        else
+          requestor:truncate(56)
+        end
+        protobuf:setRequestor(requestor)
+
+        local tableTags = {}
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+
+        protobuf:setTagArray(tableTags)
+
+        protobuf:setTag('TestLabel3,TestData3')
+
+        protobuf:setTag("Response,456")
+
+      else
+
+        local tableTags = {}                                   -- called by testProtobuf()
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+        protobuf:setTagArray(tableTags)
+
+        protobuf:setTag('TestLabel3,TestData3')
+
+        protobuf:setTag("Response,456")
+
+      end
+    end
+
+    function alterProtobufQuery(dq, protobuf)
+
+      if luasmn:check(dq.qname) then
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then
+          requestor:truncate(24)
+        else
+          requestor:truncate(56)
+        end
+        protobuf:setRequestor(requestor)
+
+        local tableTags = {}
+        tableTags = dq:getTagArray()                           -- get table from DNSQuery
+
+        local tablePB = {}
+          for k, v in pairs( tableTags) do
+          table.insert(tablePB, k .. "," .. v)
+        end
+
+        protobuf:setTagArray(tablePB)                          -- store table in protobuf
+        protobuf:setTag("Query,123")                           -- add another tag entry in protobuf
+
+        protobuf:setResponseCode(dnsdist.NXDOMAIN)             -- set protobuf response code to be NXDOMAIN
+
+        local strReqName = dq.qname:toString()                 -- get request dns name
+
+        protobuf:setProtobufResponseType()                     -- set protobuf to look like a response and not a query, with 0 default time
+
+        blobData = '\127' .. '\000' .. '\000' .. '\001'                -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
+
+        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
+
+        protobuf:setBytes(65)                                  -- set the size of the query to confirm in checkProtobufBase
+
+      else
+
+        local tableTags = {}                                    -- called by testProtobuf()
+        table.insert(tableTags, "TestLabel1,TestData1")
+        table.insert(tableTags, "TestLabel2,TestData2")
+
+        protobuf:setTagArray(tableTags)
+        protobuf:setTag('TestLabel3,TestData3')
+        protobuf:setTag("Query,123")
+
+      end
+    end
+
+    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
+      local tt = {}
+      tt["TestLabel1"] = "TestData1"
+      tt["TestLabel2"] = "TestData2"
+
+      dq:setTagArray(tt)
+
+      dq:setTag("TestLabel3","TestData3")
+      return DNSAction.None, ""                                -- continue to the next rule
+    end
+
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    rl = newRemoteLogger('127.0.0.1:%s')
+
+    addAction(AllRule(), LuaAction(alterLuaFirst))                                                     -- Add tags to DNSQuery first
+
+    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))                             -- Send protobuf message before lookup
+
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))    -- Send protobuf message after lookup
+
+    """
+
     def testProtobuf(self):
         """
         Protobuf: Send data to a protobuf server
@@ -291,7 +292,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
+        self.assertEquals(rr.rdata.decode('utf-8'), target)
         rr = msg.response.rrs[1]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
@@ -319,7 +320,7 @@ class TestProtobuf(DNSDistTest):
         self.assertEquals(len(msg.response.rrs), 2)
         rr = msg.response.rrs[0]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
-        self.assertEquals(rr.rdata, target)
+        self.assertEquals(rr.rdata.decode('utf-8'), target)
         rr = msg.response.rrs[1]
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
@@ -390,3 +391,92 @@ class TestProtobuf(DNSDistTest):
         for rr in msg.response.rrs:
             self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
             self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+class TestProtobufIPCipher(DNSDistProtobufTest):
+    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    key = makeIPCipherKey("some 16-byte key")
+    rl = newRemoteLogger('127.0.0.1:%s')
+    addAction(AllRule(), RemoteLogAction(rl, nil, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message before lookup
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, nil, true, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message after lookup
+
+    """
+
+    def testProtobuf(self):
+        """
+        Protobuf: Send data to a protobuf server
+        """
+        name = 'query.protobuf-ipcipher.tests.powerdns.com.'
+
+        target = 'target.protobuf-ipcipher.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    target)
+        response.answer.append(rrset)
+
+        rrset = dns.rrset.from_text(target,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+
+        # 108.41.239.98 is 127.0.0.1 pseudonymized with ipcipher and the current key
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '108.41.239.98')
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '108.41.239.98')
+
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        # 108.41.239.98 is 127.0.0.1 pseudonymized with ipcipher and the current key
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '108.41.239.98')
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '108.41.239.98')
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
index cfd3d614bc067e2ffd2f883d62c5420a60f8f089..4b9ce8b27a9cd372ceeff18555d2f21402abbc30 100644 (file)
@@ -23,11 +23,10 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowOneAR(self):
         """
@@ -45,19 +44,14 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseTwoAR(self):
         """
@@ -76,11 +70,10 @@ class TestRecordsCountOnlyOneAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
 
@@ -102,11 +95,10 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowTwoAN(self):
         """
@@ -126,19 +118,14 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         response = dns.message.make_response(query)
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountRefuseFourAN(self):
         """
@@ -160,11 +147,10 @@ class TestRecordsCountMoreThanOneLessThanFour(DNSDistTest):
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRecordsCountNothingInNS(DNSDistTest):
 
@@ -193,11 +179,10 @@ class TestRecordsCountNothingInNS(DNSDistTest):
         expectedResponse.set_rcode(dns.rcode.REFUSED)
         expectedResponse.authority.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 
     def testRecordsCountAllowEmptyNS(self):
@@ -216,19 +201,14 @@ class TestRecordsCountNothingInNS(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestRecordsCountNoOPTInAR(DNSDistTest):
 
@@ -249,11 +229,10 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
     def testRecordsCountAllowNoOPTInAR(self):
         """
@@ -271,19 +250,14 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testRecordsCountAllowTwoARButNoOPT(self):
         """
@@ -312,16 +286,11 @@ class TestRecordsCountNoOPTInAR(DNSDistTest):
                                                    dns.rdatatype.A,
                                                    '127.0.0.1'))
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
index 494143c6d657d8420805b574bdc31b2e50c30ef3..ffbdf94266543b5e99cffbf844324c3d3a5569bb 100644 (file)
@@ -122,15 +122,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testNotDropped(self):
         """
@@ -143,15 +140,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestResponseRuleQNameAllowed(DNSDistTest):
 
@@ -172,15 +166,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testNotAllowed(self):
         """
@@ -193,15 +184,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
 class TestResponseRuleEditTTL(DNSDistTest):
 
@@ -236,19 +224,14 @@ class TestResponseRuleEditTTL(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+            self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
 
 class TestResponseLuaActionReturnSyntax(DNSDistTest):
 
@@ -297,12 +280,9 @@ class TestResponseLuaActionReturnSyntax(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
index dfb8c1da6b7d639eac1808c227aade1589bbb1eb..e0bda09aebc4893b116377a8956efd4ab997bef3 100644 (file)
@@ -29,15 +29,12 @@ class TestRoutingPoolRouting(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testDefaultPool(self):
         """
@@ -50,11 +47,10 @@ class TestRoutingPoolRouting(DNSDistTest):
         name = 'notpool.routing.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, None)
 
 class TestRoutingQPSPoolRouting(DNSDistTest):
     _config_template = """
@@ -274,14 +270,12 @@ class TestRoutingOrder(DNSDistTest):
         response.answer.append(rrset)
 
         for _ in range(numberOfQueries):
-            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
-            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-            receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEquals(query, receivedQuery)
+                self.assertEquals(response, receivedResponse)
 
         total = 0
         if 'UDP Responder' in self._responsesCounter:
@@ -307,12 +301,10 @@ class TestRoutingNoServer(DNSDistTest):
         expectedResponse = dns.message.make_response(query)
         expectedResponse.set_rcode(dns.rcode.SERVFAIL)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertEquals(receivedResponse, expectedResponse)
-
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEquals(receivedResponse, expectedResponse)
 
 class TestRoutingWRandom(DNSDistTest):
 
index 601167e824572eb550cc030b1657bd8cf3be86a0..4cedcd32bf59294dfdc00fea6839cc146d168946 100644 (file)
@@ -30,13 +30,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionAAAA(self):
         """
@@ -57,13 +55,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionCNAME(self):
         """
@@ -84,13 +80,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     'cnameaction.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiA(self):
         """
@@ -111,13 +105,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '192.0.2.2', '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiAAAA(self):
         """
@@ -138,13 +130,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1', '2001:DB8::2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testSpoofActionMultiANY(self):
         """
@@ -173,13 +163,11 @@ class TestSpoofingSpoof(DNSDistTest):
                                     '2001:DB8::1', '2001:DB8::2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestSpoofingLuaSpoof(DNSDistTest):
 
@@ -222,13 +210,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     '192.0.2.1', '192.0.2.2')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAAAA(self):
         """
@@ -249,13 +235,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     '2001:DB8::1')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAWithCNAME(self):
         """
@@ -276,13 +260,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     'spoofedcname.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testLuaSpoofAAAAWithCNAME(self):
         """
@@ -303,13 +285,11 @@ class TestSpoofingLuaSpoof(DNSDistTest):
                                     'spoofedcname.spoofing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
 class TestSpoofingLuaWithStatistics(DNSDistTest):
 
@@ -367,10 +347,8 @@ class TestSpoofingLuaWithStatistics(DNSDistTest):
         self.assertTrue(receivedResponse)
         self.assertEquals(expectedResponse2, receivedResponse)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponseAfterwards, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponseAfterwards, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponseAfterwards, receivedResponse)
index a1672e494128142454619d6676942316fdaa91a6..984c7d1e0856b79f9957e1a06c6f7a71dcff7343 100644 (file)
@@ -2,7 +2,12 @@
 import struct
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+
+try:
+    range = xrange
+except NameError:
+    pass
 
 class TestTCPKeepAlive(DNSDistTest):
     """
@@ -13,7 +18,7 @@ class TestTCPKeepAlive(DNSDistTest):
 
     _tcpIdleTimeout = 20
     _maxTCPQueriesPerConn = 99
-    _maxTCPConnsPerClient = 3
+    _maxTCPConnsPerClient = 100
     _maxTCPConnDuration = 99
     _config_template = """
     newServer{address="127.0.0.1:%s"}
@@ -23,6 +28,7 @@ class TestTCPKeepAlive(DNSDistTest):
     setMaxTCPConnectionDuration(%s)
     pc = newPacketCache(100, 86400, 1)
     getPool(""):setCache(pc)
+    addAction("largernumberofconnections.tcpka.tests.powerdns.com.", SkipCacheAction())
     addAction("refused.tcpka.tests.powerdns.com.", RCodeAction(dnsdist.REFUSED))
     addAction("dropped.tcpka.tests.powerdns.com.", DropAction())
     addResponseAction("dropped-response.tcpka.tests.powerdns.com.", DropResponseAction())
@@ -199,6 +205,49 @@ class TestTCPKeepAlive(DNSDistTest):
         conn.close()
         self.assertEqual(count, 0)
 
+    def testTCPKaLargeNumberOfConnections(self):
+        """
+        TCP KeepAlive: Large number of connections
+        """
+        name = 'largernumberofconnections.tcpka.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        #expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        # number of connections
+        numConns = 50
+        # number of queries per connections
+        numQueriesPerConn = 4
+
+        conns = []
+        start = time.time()
+        for idx in range(numConns):
+            conns.append(self.openTCPConnection())
+
+        count = 0
+        for idx in range(numConns * numQueriesPerConn):
+            try:
+                conn = conns[idx % numConns]
+                self.sendTCPQueryOverConnection(conn, query, response=expectedResponse)
+                response = self.recvTCPResponseOverConnection(conn)
+                if response is None:
+                    break
+                self.assertEquals(expectedResponse, response)
+                count = count + 1
+            except:
+                pass
+
+        for con in conns:
+          conn.close()
+
+        self.assertEqual(count, numConns * numQueriesPerConn)
+
 class TestTCPKeepAliveNoDownstreamDrop(DNSDistTest):
     """
     This test makes sure that dnsdist drops the TCP connection
index e67a152d375db844a39c2a02a765ba6745169a7b..ec20c929f5966591f857c55e9e0d9a2e52730f37 100644 (file)
@@ -2,7 +2,12 @@
 import struct
 import time
 import dns
-from dnsdisttests import DNSDistTest, range
+from dnsdisttests import DNSDistTest
+
+try:
+  range = xrange
+except NameError:
+  pass
 
 class TestTCPLimits(DNSDistTest):
 
@@ -101,19 +106,20 @@ class TestTCPLimits(DNSDistTest):
         conn.send(struct.pack("!H", 65535))
 
         count = 0
-        while count < (self._maxTCPConnDuration * 2):
+        while count < (self._maxTCPConnDuration * 20):
             try:
                 # sleeping for only one second keeps us below the
                 # idle timeout (setTCPRecvTimeout())
-                time.sleep(1)
-                conn.send('A')
+                time.sleep(0.1)
+                conn.send(b'A')
                 count = count + 1
-            except:
+            except Exception as e:
+                print("Exception: %s!" % (e))
                 break
 
         end = time.time()
 
-        self.assertAlmostEquals(count, self._maxTCPConnDuration, delta=2)
+        self.assertAlmostEquals(count / 10, self._maxTCPConnDuration, delta=2)
         self.assertAlmostEquals(end - start, self._maxTCPConnDuration, delta=2)
 
         conn.close()
index 68aaa83e6d18562556964e00ff2047f7f63e407d..b31c6c2f9bfcf4933bdc99103b5eb2febbf2a035 100644 (file)
@@ -38,3 +38,55 @@ class TestTLS(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
+
+    def testTLKA(self):
+        """
+        TLS: Several queries over the same connection
+        """
+        name = 'ka.tls.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+        for idx in range(5):
+            self.sendTCPQueryOverConnection(conn, query, response=response)
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+
+    def testTLSPipelining(self):
+        """
+        TLS: Several queries over the same connection without waiting for the responses
+        """
+        name = 'pipelining.tls.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
+
+        for idx in range(100):
+            self.sendTCPQueryOverConnection(conn, query, response=response)
+
+        for idx in range(100):
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
index 6da32b9a0f14f3ce18d7751a5ca6463b2fc84fa8..736c87bf950006f0200a5ebcb1888a11cd60bee9 100644 (file)
@@ -55,19 +55,14 @@ class TestBasics(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testQuestionMatchTagAndValue(self):
         """
@@ -85,13 +80,11 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.50')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testQuestionMatchTagOnly(self):
         """
@@ -109,13 +102,11 @@ class TestBasics(DNSDistTest):
                                     '1.2.3.100')
         expectedResponse.answer.append(rrset)
 
-        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
-        self.assertTrue(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testResponseNoMatch(self):
         """
@@ -131,19 +122,14 @@ class TestBasics(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testResponseMatchTagAndValue(self):
         """
@@ -165,21 +151,14 @@ class TestBasics(DNSDistTest):
         # we will set TC if the tag matches
         expectedResponse.flags |= dns.flags.TC
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        print(expectedResponse)
-        print(receivedResponse)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
 
     def testResponseMatchResponseTagMatches(self):
         """
@@ -201,16 +180,11 @@ class TestBasics(DNSDistTest):
         # we will set QR=0 if the tag matches
         expectedResponse.flags &= ~dns.flags.QR
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(expectedResponse, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(expectedResponse, receivedResponse)
index a9c74b58556bdeaf7f3e38eb9048f33c031e3ee7..b87b519f7ad7b9fe11f42ee82ec84383ab497cb8 100644 (file)
@@ -20,7 +20,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("added.trailing.tests.powerdns.com.", replaceTrailingData)
+    addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
 
     function fillBuffer(dq)
         local available = dq.size - dq.len
@@ -31,7 +31,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("max.trailing.tests.powerdns.com.", fillBuffer)
+    addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))
 
     function exceedBuffer(dq)
         local available = dq.size - dq.len
@@ -42,7 +42,7 @@ class TestTrailingDataToBackend(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
+    addAction("limited.trailing.tests.powerdns.com.", LuaAction(exceedBuffer))
     """
     @classmethod
     def startResponders(cls):
@@ -80,8 +80,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -108,8 +106,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -136,8 +132,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (_, receivedResponse) = sender(query, response)
             self.assertTrue(receivedResponse)
             self.assertEquals(receivedResponse, expectedResponse)
@@ -161,8 +155,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -183,13 +175,13 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)
+    addAction("removed.trailing.tests.powerdns.com.", LuaAction(removeTrailingData))
 
     function reportTrailingData(dq)
         local tail = dq:getTrailingData()
         return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
     end
-    addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("echoed.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
 
     function replaceTrailingData(dq)
         local success = dq:setTrailingData("ABC")
@@ -198,8 +190,8 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("replaced.trailing.tests.powerdns.com.", replaceTrailingData)
-    addLuaAction("replaced.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
 
     function reportTrailingHex(dq)
         local tail = dq:getTrailingData()
@@ -208,7 +200,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end)
         return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
     end
-    addLuaAction("echoed-hex.trailing.tests.powerdns.com.", reportTrailingHex)
+    addAction("echoed-hex.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
 
     function replaceTrailingData_unsafe(dq)
         local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
@@ -217,8 +209,8 @@ class TestTrailingDataToDnsdist(DNSDistTest):
         end
         return DNSAction.None, ""
     end
-    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", replaceTrailingData_unsafe)
-    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", reportTrailingHex)
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData_unsafe))
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
     """
 
     def testTrailingDropped(self):
@@ -243,8 +235,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             sender = getattr(self, method)
 
             # Verify that queries with no trailing data make it through.
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -253,8 +243,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertEquals(response, receivedResponse)
 
             # Verify that queries with trailing data don't make it through.
-            # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertEquals(receivedResponse, None)
 
@@ -278,8 +266,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -309,8 +295,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -338,8 +322,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -367,8 +349,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -396,8 +376,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
index 7da9f05040c362b54537029f5c770562268bce15..26038091e8e18cb70e872ae4458a585266a75412 100644 (file)
@@ -184,6 +184,8 @@ class testECSByNameLarger(ECSTest):
     _config_template = """edns-subnet-whitelist=ecs-echo.example.
 ecs-ipv4-bits=32
 forward-zones=ecs-echo.example=%s.21
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -244,6 +246,8 @@ class testIncomingECSByName(ECSTest):
 use-incoming-edns-subnet=yes
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=2001:db8::42
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -283,6 +287,8 @@ use-incoming-edns-subnet=yes
 ecs-ipv4-bits=32
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=192.168.0.1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -313,6 +319,8 @@ use-incoming-edns-subnet=yes
 ecs-ipv4-bits=16
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=192.168.0.1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'])
 
     def testSendECS(self):
@@ -339,6 +347,8 @@ class testIncomingECSByNameV6(ECSTest):
     _config_template = """edns-subnet-whitelist=ecs-echo.example.
 use-incoming-edns-subnet=yes
 ecs-ipv6-bits=128
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
 forward-zones=ecs-echo.example=%s.21
 query-local-address6=::1
     """ % (os.environ['PREFIX'])
@@ -423,6 +433,8 @@ class testIncomingECSByIP(ECSTest):
 use-incoming-edns-subnet=yes
 forward-zones=ecs-echo.example=%s.21
 ecs-scope-zero-address=::1
+ecs-ipv4-cache-bits=32
+ecs-ipv6-cache-bits=128
     """ % (os.environ['PREFIX'], os.environ['PREFIX'])
 
     def testSendECS(self):