From f4fef543d41fc013c820d1bc981725487d5a3a64 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Sun, 24 Nov 2019 19:20:00 +0000 Subject: [PATCH] downloader: Check DNS for most recent version Signed-off-by: Michael Tremer --- Makefile.am | 5 + configure.ac | 25 +++-- man/location-downloader.txt | 6 ++ src/libloc.sym | 2 + src/loc/format.h | 3 + src/loc/resolv.h | 26 ++++++ src/python/location-downloader.in | 113 ++++++++++++----------- src/python/locationmodule.c | 25 +++++ src/resolv.c | 147 ++++++++++++++++++++++++++++++ 9 files changed, 289 insertions(+), 63 deletions(-) create mode 100644 src/loc/resolv.h create mode 100644 src/resolv.c diff --git a/Makefile.am b/Makefile.am index b937824..63efb17 100644 --- a/Makefile.am +++ b/Makefile.am @@ -89,6 +89,7 @@ pkginclude_HEADERS = \ src/loc/network.h \ src/loc/private.h \ src/loc/stringpool.h \ + src/loc/resolv.h \ src/loc/writer.h lib_LTLIBRARIES = \ @@ -100,6 +101,7 @@ src_libloc_la_SOURCES = \ src/country.c \ src/database.c \ src/network.c \ + src/resolv.c \ src/stringpool.c \ src/writer.c @@ -120,6 +122,9 @@ else src_libloc_la_LDFLAGS += -export-symbols $(top_srcdir)/src/libloc.sym endif +src_libloc_la_LIBADD = \ + $(RESOLV_LIBS) + src_libloc_la_DEPENDENCIES = \ ${top_srcdir}/src/libloc.sym diff --git a/configure.ac b/configure.ac index cd224eb..57c5b04 100644 --- a/configure.ac +++ b/configure.ac @@ -62,23 +62,26 @@ AS_IF([test "x$enable_debug" = "xyes"], [ AC_CHECK_HEADERS_ONCE([ arpa/inet.h \ + arpa/nameser.h \ endian.h \ netinet/in.h \ + resolv.h \ string.h \ ]) AC_CHECK_FUNCS([ \ - be16toh \ - be32toh \ - be64toh \ - htobe16 \ - htobe32 \ - htobe64 \ - mmap \ - munmap \ + be16toh \ + be32toh \ + be64toh \ + htobe16 \ + htobe32 \ + htobe64 \ + mmap \ + munmap \ + res_query \ __secure_getenv \ secure_getenv \ - qsort \ + qsort \ ]) my_CFLAGS="\ @@ -145,6 +148,10 @@ AX_PROG_PERL_MODULES(ExtUtils::MakeMaker,, AC_MSG_WARN(Need some Perl modules)) AC_ARG_ENABLE(perl, AS_HELP_STRING([--disable-perl], [do not build the perl modules]), [],[enable_perl=yes]) AM_CONDITIONAL(ENABLE_PERL, test "$enable_perl" = "yes") +dnl Checking for libresolv +AC_CHECK_LIB(resolv, ns_msg_getflag, [LIBS="-lresolv $LIBS"], AC_MSG_ERROR([libresolv has not been found]), -lresolv) +RESOLV_LIBS="${LIBS}" + AC_CONFIG_HEADERS(config.h) AC_CONFIG_FILES([ Makefile diff --git a/man/location-downloader.txt b/man/location-downloader.txt index a03b60f..5faa5cd 100644 --- a/man/location-downloader.txt +++ b/man/location-downloader.txt @@ -38,6 +38,12 @@ The 'location-downloader' command will normally exit with code zero. If there has been a problem and the requested action could not be performed, the exit code is unequal to zero. +== HOW IT WORKS +The downloader checks a DNS record for the latest version of the database. +It will then try to download a file with that version from a mirror server. +If the downloaded file is outdated, the next mirror will be tried until we +have found a file that is recent enough. + == BUGS Please report all bugs to the bugtracker at https://bugzilla.ipfire.org/. diff --git a/src/libloc.sym b/src/libloc.sym index 7cd7405..12f82ba 100644 --- a/src/libloc.sym +++ b/src/libloc.sym @@ -26,6 +26,7 @@ global: loc_unref; loc_set_log_priority; loc_new; + loc_discover_latest_version; # AS loc_as_cmp; @@ -69,6 +70,7 @@ global: loc_database_enumerator_ref; loc_database_enumerator_set_asn; loc_database_enumerator_set_country_code; + loc_database_enumerator_set_flag; loc_database_enumerator_set_string; loc_database_enumerator_unref; diff --git a/src/loc/format.h b/src/loc/format.h index f93b015..3679828 100644 --- a/src/loc/format.h +++ b/src/loc/format.h @@ -25,6 +25,9 @@ #define LOC_DATABASE_VERSION 0 +#define STR(x) #x +#define LOC_DATABASE_DOMAIN_LATEST(version) "_latest._v" STR(version) ".location.ipfire.org" + #define LOC_DATABASE_PAGE_SIZE 4096 struct loc_database_magic { diff --git a/src/loc/resolv.h b/src/loc/resolv.h new file mode 100644 index 0000000..3b5e990 --- /dev/null +++ b/src/loc/resolv.h @@ -0,0 +1,26 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2019 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. +*/ + +#ifndef LIBLOC_RESOLV_H +#define LIBLOC_RESOLV_H + +#include + +#include + +int loc_discover_latest_version(struct loc_ctx* ctx, const char* domain, time_t* t); + +#endif diff --git a/src/python/location-downloader.in b/src/python/location-downloader.in index 961c5df..4fdf404 100644 --- a/src/python/location-downloader.in +++ b/src/python/location-downloader.in @@ -18,6 +18,7 @@ ############################################################################### import argparse +import datetime import gettext import logging import logging.handlers @@ -72,12 +73,6 @@ def _(singular, plural=None, n=None): return gettext.dgettext("libloc", singular) -class NotModifiedError(Exception): - """ - Raised when the file has not been modified on the server - """ - pass - class Downloader(object): def __init__(self, mirrors): @@ -139,10 +134,6 @@ class Downloader(object): for header in e.headers: log.debug(" %s: %s" % (header, e.headers[header])) - # Handle 304 - if e.code == 304: - raise NotModifiedError() from e - # Raise all other errors raise e @@ -154,12 +145,12 @@ class Downloader(object): return res - def download(self, url, mtime=None, **kwargs): + def download(self, url, timestamp=None, **kwargs): headers = {} - if mtime: - headers["If-Modified-Since"] = time.strftime( - "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime), + if timestamp: + headers["If-Modified-Since"] = timestamp.strftime( + "%a, %d %b %Y %H:%M:%S GMT", ) t = tempfile.NamedTemporaryFile(delete=False) @@ -184,37 +175,59 @@ class Downloader(object): if buf: t.write(buf) - # Write all data to disk - t.flush() - - # Nothing to do when the database on the server is up to date - except NotModifiedError: - log.info("Local database is up to date") - return + # Write all data to disk + t.flush() # Catch decompression errors except lzma.LZMAError as e: log.warning("Could not decompress downloaded file: %s" % e) continue - # XXX what do we catch here? except urllib.error.HTTPError as e: - if e.code == 404: - continue + # The file on the server was too old + if e.code == 304: + log.warning("%s is serving an outdated database. Trying next mirror..." % mirror) - # Truncate the target file and drop downloaded content - try: - t.truncate() - except OSError: - pass + # Log any other HTTP errors + else: + log.warning("%s reported: %s" % (mirror, e)) + + # Throw away any downloaded content and try again + t.truncate() - raise e + else: + # Check if the downloaded database is recent + if not self._check_database(t, timestamp): + log.warning("Downloaded database is outdated. Trying next mirror...") - # Return temporary file - return t + # Throw away the data and try again + t.truncate() + continue + + # Return temporary file + return t raise FileNotFoundError(url) + def _check_database(self, f, timestamp=None): + """ + Checks the downloaded database if it can be opened, + verified and if it is recent enough + """ + log.debug("Opening downloaded database at %s" % f.name) + + db = location.Database(f.name) + + # Database is not recent + if timestamp and db.created_at < timestamp.timestamp(): + return False + + log.info("Downloaded new database from %s" % (time.strftime( + "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at), + ))) + + return True + class CLI(object): def __init__(self): @@ -271,20 +284,30 @@ class CLI(object): sys.exit(0) def handle_update(self, ns): - mtime = None + # Fetch the version we need from DNS + t = location.discover_latest_version() + + # Parse timestamp into datetime format + try: + timestamp = datetime.datetime.fromtimestamp(t) + except: + raise # Open database try: db = location.Database(ns.database) - # Get mtime of the old file - mtime = os.path.getmtime(ns.database) + # Check if we are already on the latest version + if db.created_at >= timestamp.timestamp(): + log.info("Already on the latest version") + return + except FileNotFoundError as e: db = None # Try downloading a new database try: - t = self.downloader.download(DATABASE_FILENAME, mtime=mtime) + t = self.downloader.download(DATABASE_FILENAME, timestamp=timestamp) # If no file could be downloaded, log a message except FileNotFoundError as e: @@ -295,24 +318,6 @@ class CLI(object): if not t: return 0 - # Save old database creation time - created_at = db.created_at if db else 0 - - # Try opening the downloaded file - try: - db = location.Database(t.name) - except Exception as e: - raise e - - # Check if the downloaded file is newer - if db.created_at <= created_at: - log.warning("Downloaded database is older than the current version") - return 1 - - log.info("Downloaded new database from %s" % (time.strftime( - "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at), - ))) - # Write temporary file to destination shutil.copyfile(t.name, ns.database) diff --git a/src/python/locationmodule.c b/src/python/locationmodule.c index d413260..78cf075 100644 --- a/src/python/locationmodule.c +++ b/src/python/locationmodule.c @@ -17,6 +17,8 @@ #include #include +#include + #include "locationmodule.h" #include "as.h" #include "country.h" @@ -46,7 +48,30 @@ static PyObject* set_log_level(PyObject* m, PyObject* args) { Py_RETURN_NONE; } +static PyObject* discover_latest_version(PyObject* m, PyObject* args) { + const char* domain = NULL; + + if (!PyArg_ParseTuple(args, "|s", &domain)) + return NULL; + + time_t t = 0; + + int r = loc_discover_latest_version(loc_ctx, domain, &t); + if (r) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + + return PyLong_FromUnsignedLong(t); +} + static PyMethodDef location_module_methods[] = { + { + "discover_latest_version", + (PyCFunction)discover_latest_version, + METH_VARARGS, + NULL, + }, { "set_log_level", (PyCFunction)set_log_level, diff --git a/src/resolv.c b/src/resolv.c new file mode 100644 index 0000000..3ad0f27 --- /dev/null +++ b/src/resolv.c @@ -0,0 +1,147 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2019 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. +*/ + +#include +#include +#include + +#include +#include +#include + +static int parse_timestamp(const unsigned char* txt, time_t* t) { + struct tm ts; + + // Parse timestamp + char* p = strptime((const char*)txt, "%a, %d %b %Y %H:%M:%S GMT", &ts); + + // If the whole string has been parsed, we convert the parse value to time_t + if (p && !*p) { + *t = mktime(&ts); + + // Otherwise we reset t + } else { + *t = 0; + return -1; + } + + return 0; +} + +LOC_EXPORT int loc_discover_latest_version(struct loc_ctx* ctx, const char* domain, time_t* t) { + // Initialise the resolver + int r = res_init(); + if (r) { + ERROR(ctx, "res_init() failed\n"); + return r; + } + + // Fall back to default domain + if (!domain) + domain = LOC_DATABASE_DOMAIN_LATEST(LOC_DATABASE_VERSION); + + unsigned char answer[PACKETSZ]; + int len; + + DEBUG(ctx, "Querying %s\n", domain); + + // Send a query + if ((len = res_query(domain, C_IN, T_TXT, answer, sizeof(answer))) < 0 || len > PACKETSZ) { + ERROR(ctx, "Could not query %s: \n", domain); + + return -1; + } + + unsigned char* end = answer + len; + unsigned char* payload = answer + sizeof(HEADER); + + // Expand domain name + char host[128]; + if ((len = dn_expand(answer, end, payload, host, sizeof(host))) < 0) { + ERROR(ctx, "dn_expand() failed\n"); + return -1; + } + + // Payload starts after hostname + payload += len; + + if (payload > end - 4) { + ERROR(ctx, "DNS reply too short\n"); + return -1; + } + + int type; + GETSHORT(type, payload); + if (type != T_TXT) { + ERROR(ctx, "DNS reply of unexpected type: %d\n", type); + return -1; + } + + // Skip class + payload += INT16SZ; + + // Walk through CNAMEs + unsigned int size = 0; + int ttl; + do { + payload += size; + + if ((len = dn_expand(answer, end, payload, host, sizeof(host))) < 0) { + ERROR(ctx, "dn_expand() failed\n"); + return -1; + } + + payload += len; + + if (payload > end - 10) { + ERROR(ctx, "DNS reply too short\n"); + return -1; + } + + // Skip type, class, ttl + GETSHORT(type, payload); + payload += INT16SZ; + GETLONG(ttl, payload); + + // Read size + GETSHORT(size, payload); + if (payload + size < answer || payload + size > end) { + ERROR(ctx, "DNS RR overflow\n"); + return -1; + } + } while (type == T_CNAME); + + if (type != T_TXT) { + ERROR(ctx, "Not a TXT record\n"); + return -1; + } + + if (!size || (len = *payload) >= size || !len) { + ERROR(ctx, "Broken TXT record (len = %d, size = %d)\n", len, size); + return -1; + } + + // Get start of the string + unsigned char* txt = payload + 1; + txt[len] = '\0'; + + DEBUG(ctx, "Resolved to: %s\n", txt); + + // Parse timestamp + r = parse_timestamp(txt, t); + + return r; +} -- 2.39.2