]> git.ipfire.org Git - location/libloc.git/commitdiff
downloader: Check DNS for most recent version
authorMichael Tremer <michael.tremer@ipfire.org>
Sun, 24 Nov 2019 19:20:00 +0000 (19:20 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sun, 24 Nov 2019 19:35:06 +0000 (19:35 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
configure.ac
man/location-downloader.txt
src/libloc.sym
src/loc/format.h
src/loc/resolv.h [new file with mode: 0644]
src/python/location-downloader.in
src/python/locationmodule.c
src/resolv.c [new file with mode: 0644]

index b937824f2e025e8353ef916d7010ff1182bd02ca..63efb17151530f3cb12cfcd74878c379d43f89f6 100644 (file)
@@ -89,6 +89,7 @@ pkginclude_HEADERS = \
        src/loc/network.h \
        src/loc/private.h \
        src/loc/stringpool.h \
+       src/loc/resolv.h \
        src/loc/writer.h
 
 lib_LTLIBRARIES = \
@@ -100,6 +101,7 @@ src_libloc_la_SOURCES = \
        src/country.c \
        src/database.c \
        src/network.c \
+       src/resolv.c \
        src/stringpool.c \
        src/writer.c
 
@@ -120,6 +122,9 @@ else
 src_libloc_la_LDFLAGS += -export-symbols $(top_srcdir)/src/libloc.sym
 endif
 
+src_libloc_la_LIBADD = \
+       $(RESOLV_LIBS)
+
 src_libloc_la_DEPENDENCIES = \
        ${top_srcdir}/src/libloc.sym
 
index cd224eb423ee201334617d3093ad5adebe20e564..57c5b0499c3f17131d0c2f84e37299188de35d47 100644 (file)
@@ -62,23 +62,26 @@ AS_IF([test "x$enable_debug" = "xyes"], [
 
 AC_CHECK_HEADERS_ONCE([
        arpa/inet.h \
+    arpa/nameser.h \
        endian.h \
        netinet/in.h \
+    resolv.h \
        string.h \
 ])
 
 AC_CHECK_FUNCS([ \
-        be16toh \
-        be32toh \
-        be64toh \
-        htobe16 \
-        htobe32 \
-        htobe64 \
-        mmap \
-        munmap \
+    be16toh \
+    be32toh \
+    be64toh \
+    htobe16 \
+    htobe32 \
+    htobe64 \
+    mmap \
+    munmap \
+    res_query \
        __secure_getenv \
        secure_getenv \
-        qsort \
+    qsort \
 ])
 
 my_CFLAGS="\
@@ -145,6 +148,10 @@ AX_PROG_PERL_MODULES(ExtUtils::MakeMaker,, AC_MSG_WARN(Need some Perl modules))
 AC_ARG_ENABLE(perl, AS_HELP_STRING([--disable-perl], [do not build the perl modules]), [],[enable_perl=yes])
 AM_CONDITIONAL(ENABLE_PERL, test "$enable_perl" = "yes")
 
+dnl Checking for libresolv
+AC_CHECK_LIB(resolv, ns_msg_getflag, [LIBS="-lresolv $LIBS"], AC_MSG_ERROR([libresolv has not been found]), -lresolv)
+RESOLV_LIBS="${LIBS}"
+
 AC_CONFIG_HEADERS(config.h)
 AC_CONFIG_FILES([
         Makefile
index a03b60f96bc5f8d3818a2c46b778e1683c449948..5faa5cd2ca83fc9e1ab7c130e22504269f9c0401 100644 (file)
@@ -38,6 +38,12 @@ The 'location-downloader' command will normally exit with code zero.
 If there has been a problem and the requested action could not be performed,
 the exit code is unequal to zero.
 
+== HOW IT WORKS
+The downloader checks a DNS record for the latest version of the database.
+It will then try to download a file with that version from a mirror server.
+If the downloaded file is outdated, the next mirror will be tried until we
+have found a file that is recent enough.
+
 == BUGS
 Please report all bugs to the bugtracker at https://bugzilla.ipfire.org/.
 
index 7cd7405c273f92ce800da0994d48c8db342b1b3f..12f82ba88a8a97b8f8c4a25b1d2ede7e4640f24d 100644 (file)
@@ -26,6 +26,7 @@ global:
        loc_unref;
        loc_set_log_priority;
        loc_new;
+       loc_discover_latest_version;
 
        # AS
        loc_as_cmp;
@@ -69,6 +70,7 @@ global:
        loc_database_enumerator_ref;
        loc_database_enumerator_set_asn;
        loc_database_enumerator_set_country_code;
+       loc_database_enumerator_set_flag;
        loc_database_enumerator_set_string;
        loc_database_enumerator_unref;
 
index f93b015eb4482489670bf7bde2638041484b9c97..3679828d1a9d4cbd04a31bf06c5f19c8b9f1af0b 100644 (file)
@@ -25,6 +25,9 @@
 
 #define LOC_DATABASE_VERSION    0
 
+/* Turns its argument into a string literal. Used directly, a macro name
+   passed as x is stringified as written and not replaced by its value. */
+#define STR(x) #x
+/* Builds the DNS name "_latest._v<version>.location.ipfire.org" as one string
+   literal. The argument is macro-expanded before it is handed on to STR(),
+   so passing LOC_DATABASE_VERSION inserts its numeric value. */
+#define LOC_DATABASE_DOMAIN_LATEST(version) "_latest._v" STR(version) ".location.ipfire.org"
+
 #define LOC_DATABASE_PAGE_SIZE  4096
 
 struct loc_database_magic {
diff --git a/src/loc/resolv.h b/src/loc/resolv.h
new file mode 100644 (file)
index 0000000..3b5e990
--- /dev/null
@@ -0,0 +1,26 @@
+/*
+       libloc - A library to determine the location of someone on the Internet
+
+       Copyright (C) 2019 IPFire Development Team <info@ipfire.org>
+
+       This library is free software; you can redistribute it and/or
+       modify it under the terms of the GNU Lesser General Public
+       License as published by the Free Software Foundation; either
+       version 2.1 of the License, or (at your option) any later version.
+
+       This library is distributed in the hope that it will be useful,
+       but WITHOUT ANY WARRANTY; without even the implied warranty of
+       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+       Lesser General Public License for more details.
+*/
+
+#ifndef LIBLOC_RESOLV_H
+#define LIBLOC_RESOLV_H
+
+#include <time.h>
+
+#include <loc/libloc.h>
+
+/*
+       Looks up the TXT record of "domain" in DNS and parses its content as a
+       timestamp which is stored in "t".
+
+       If "domain" is NULL, the default domain for LOC_DATABASE_VERSION is queried.
+
+       Returns zero on success and a non-zero value if the resolver could not be
+       initialised, the query failed, or the reply could not be parsed.
+*/
+int loc_discover_latest_version(struct loc_ctx* ctx, const char* domain, time_t* t);
+
+#endif
index 961c5dffdaf631344588ebf04df8ebf1fcb51c9c..4fdf4042a296e761546e7df8538bed980ea6549b 100644 (file)
@@ -18,6 +18,7 @@
 ###############################################################################
 
 import argparse
+import datetime
 import gettext
 import logging
 import logging.handlers
@@ -72,12 +73,6 @@ def _(singular, plural=None, n=None):
 
        return gettext.dgettext("libloc", singular)
 
-class NotModifiedError(Exception):
-       """
-               Raised when the file has not been modified on the server
-       """
-       pass
-
 
 class Downloader(object):
        def __init__(self, mirrors):
@@ -139,10 +134,6 @@ class Downloader(object):
                        for header in e.headers:
                                log.debug("             %s: %s" % (header, e.headers[header]))
 
-                       # Handle 304
-                       if e.code == 304:
-                               raise NotModifiedError() from e
-
                        # Raise all other errors
                        raise e
 
@@ -154,12 +145,12 @@ class Downloader(object):
 
                return res
 
-       def download(self, url, mtime=None, **kwargs):
+       def download(self, url, timestamp=None, **kwargs):
                headers = {}
 
-               if mtime:
-                       headers["If-Modified-Since"] = time.strftime(
-                               "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+               if timestamp:
+                       headers["If-Modified-Since"] = timestamp.strftime(
+                               "%a, %d %b %Y %H:%M:%S GMT",
                        )
 
                t = tempfile.NamedTemporaryFile(delete=False)
@@ -184,37 +175,59 @@ class Downloader(object):
                                                        if buf:
                                                                t.write(buf)
 
-                                               # Write all data to disk
-                                               t.flush()
-
-                               # Nothing to do when the database on the server is up to date
-                               except NotModifiedError:
-                                       log.info("Local database is up to date")
-                                       return
+                                       # Write all data to disk
+                                       t.flush()
 
                                # Catch decompression errors
                                except lzma.LZMAError as e:
                                        log.warning("Could not decompress downloaded file: %s" % e)
                                        continue
 
-                               # XXX what do we catch here?
                                except urllib.error.HTTPError as e:
-                                       if e.code == 404:
-                                               continue
+                                       # The file on the server was too old
+                                       if e.code == 304:
+                                               log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)
 
-                                       # Truncate the target file and drop downloaded content
-                                       try:
-                                               t.truncate()
-                                       except OSError:
-                                               pass
+                                       # Log any other HTTP errors
+                                       else:
+                                               log.warning("%s reported: %s" % (mirror, e))
+
+                                       # Throw away any downloaded content and try again
+                                       t.truncate()
 
-                                       raise e
+                               else:
+                                       # Check if the downloaded database is recent
+                                       if not self._check_database(t, timestamp):
+                                               log.warning("Downloaded database is outdated. Trying next mirror...")
 
-                               # Return temporary file
-                               return t
+                                               # Throw away the data and try again
+                                               t.truncate()
+                                               continue
+
+                                       # Return temporary file
+                                       return t
 
                raise FileNotFoundError(url)
 
+       def _check_database(self, f, timestamp=None):
+               """
+                       Checks the downloaded database if it can be opened,
+                       verified and if it is recent enough
+               """
+               log.debug("Opening downloaded database at %s" % f.name)
+
+               db = location.Database(f.name)
+
+               # Database is not recent
+               if timestamp and db.created_at < timestamp.timestamp():
+                       return False
+
+               log.info("Downloaded new database from %s" % (time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               )))
+
+               return True
+
 
 class CLI(object):
        def __init__(self):
@@ -271,20 +284,30 @@ class CLI(object):
                sys.exit(0)
 
        def handle_update(self, ns):
-               mtime = None
+               # Fetch the version we need from DNS
+               t = location.discover_latest_version()
+
+               # Parse timestamp into datetime format
+               try:
+                       timestamp = datetime.datetime.fromtimestamp(t)
+               except:
+                       raise
 
                # Open database
                try:
                        db = location.Database(ns.database)
 
-                       # Get mtime of the old file
-                       mtime = os.path.getmtime(ns.database)
+                       # Check if we are already on the latest version
+                       if db.created_at >= timestamp.timestamp():
+                               log.info("Already on the latest version")
+                               return
+
                except FileNotFoundError as e:
                        db = None
 
                # Try downloading a new database
                try:
-                       t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+                       t = self.downloader.download(DATABASE_FILENAME, timestamp=timestamp)
 
                # If no file could be downloaded, log a message
                except FileNotFoundError as e:
@@ -295,24 +318,6 @@ class CLI(object):
                if not t:
                        return 0
 
-               # Save old database creation time
-               created_at = db.created_at if db else 0
-
-               # Try opening the downloaded file
-               try:
-                       db = location.Database(t.name)
-               except Exception as e:
-                       raise e
-
-               # Check if the downloaded file is newer
-               if db.created_at <= created_at:
-                       log.warning("Downloaded database is older than the current version")
-                       return 1
-
-               log.info("Downloaded new database from %s" % (time.strftime(
-                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
-               )))
-
                # Write temporary file to destination
                shutil.copyfile(t.name, ns.database)
 
index d4132609a2111f3c3260d6f3e692739accb6095c..78cf07539b60ee002f7e951b8f5cd3c7f550c360 100644 (file)
@@ -17,6 +17,8 @@
 #include <Python.h>
 #include <syslog.h>
 
+#include <loc/resolv.h>
+
 #include "locationmodule.h"
 #include "as.h"
 #include "country.h"
@@ -46,7 +48,30 @@ static PyObject* set_log_level(PyObject* m, PyObject* args) {
        Py_RETURN_NONE;
 }
 
+/*
+       Python: discover_latest_version([domain])
+
+       Calls loc_discover_latest_version() with the optional domain (NULL if it
+       was omitted) and returns the resulting timestamp as an integer.
+       A non-zero return code of the library is raised as OSError from errno.
+*/
+static PyObject* discover_latest_version(PyObject* m, PyObject* args) {
+       const char* domain = NULL;
+       time_t latest = 0;
+
+       // The domain argument is optional
+       if (!PyArg_ParseTuple(args, "|s", &domain))
+               return NULL;
+
+       // PyErr_SetFromErrno() always returns NULL
+       if (loc_discover_latest_version(loc_ctx, domain, &latest) != 0)
+               return PyErr_SetFromErrno(PyExc_OSError);
+
+       return PyLong_FromUnsignedLong(latest);
+}
+
 static PyMethodDef location_module_methods[] = {
+       {
+               "discover_latest_version",
+               (PyCFunction)discover_latest_version,
+               METH_VARARGS,
+               NULL,
+       },
        {
                "set_log_level",
                (PyCFunction)set_log_level,
diff --git a/src/resolv.c b/src/resolv.c
new file mode 100644 (file)
index 0000000..3ad0f27
--- /dev/null
@@ -0,0 +1,147 @@
+/*
+       libloc - A library to determine the location of someone on the Internet
+
+       Copyright (C) 2019 IPFire Development Team <info@ipfire.org>
+
+       This library is free software; you can redistribute it and/or
+       modify it under the terms of the GNU Lesser General Public
+       License as published by the Free Software Foundation; either
+       version 2.1 of the License, or (at your option) any later version.
+
+       This library is distributed in the hope that it will be useful,
+       but WITHOUT ANY WARRANTY; without even the implied warranty of
+       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+       Lesser General Public License for more details.
+*/
+
+#include <resolv.h>
+#include <string.h>
+#include <time.h>
+
+#include <loc/format.h>
+#include <loc/private.h>
+#include <loc/resolv.h>
+
+/*
+    Parses a timestamp of the form "Sun, 24 Nov 2019 19:20:00 GMT".
+
+    On success, the time is stored in t as seconds since the epoch (UTC) and
+    zero is returned. If txt does not match the format or has trailing
+    characters, t is reset to zero and -1 is returned.
+*/
+static int parse_timestamp(const unsigned char* txt, time_t* t) {
+    // strptime() only fills in the fields it has parsed, so start from zero
+    struct tm ts = { 0 };
+
+    // Parse timestamp
+    char* p = strptime((const char*)txt, "%a, %d %b %Y %H:%M:%S GMT", &ts);
+
+    // Reject the input unless the whole string has been consumed
+    if (!p || *p) {
+        *t = 0;
+        return -1;
+    }
+
+    // The timestamp is given in GMT and must not be interpreted in the local
+    // timezone like mktime() would do. timegm() is a GNU/BSD extension which
+    // is declared in <time.h>.
+    *t = timegm(&ts);
+
+    return 0;
+}
+
+/*
+    Discovers the timestamp of the most recent database by querying DNS.
+
+    The TXT record of "domain" is requested, any CNAME records at the start of
+    the answer are skipped, and the first string of the TXT record is parsed
+    as a timestamp into "t". If domain is NULL, the default domain for
+    LOC_DATABASE_VERSION is used.
+
+    Returns zero on success and a non-zero value if the resolver could not be
+    initialised, the query failed, or the reply could not be parsed.
+*/
+LOC_EXPORT int loc_discover_latest_version(struct loc_ctx* ctx, const char* domain, time_t* t) {
+    // Initialise the resolver
+    int r = res_init();
+    if (r) {
+        ERROR(ctx, "res_init() failed\n");
+        return r;
+    }
+
+    // Fall back to default domain
+    if (!domain)
+        domain = LOC_DATABASE_DOMAIN_LATEST(LOC_DATABASE_VERSION);
+
+    unsigned char answer[PACKETSZ];
+    int len;
+
+    DEBUG(ctx, "Querying %s\n", domain);
+
+    // Send a query
+    len = res_query(domain, C_IN, T_TXT, answer, sizeof(answer));
+    if (len < 0 || len > PACKETSZ) {
+        ERROR(ctx, "Could not query %s: \n", domain);
+
+        return -1;
+    }
+
+    // The reply must at least hold the fixed-size DNS header
+    if ((size_t)len < sizeof(HEADER)) {
+        ERROR(ctx, "DNS reply too short\n");
+        return -1;
+    }
+
+    unsigned char* end = answer + len;
+    unsigned char* payload = answer + sizeof(HEADER);
+
+    // Expand domain name of the question
+    char host[128];
+    if ((len = dn_expand(answer, end, payload, host, sizeof(host))) < 0) {
+        ERROR(ctx, "dn_expand() failed\n");
+        return -1;
+    }
+
+    // Payload starts after hostname
+    payload += len;
+
+    // The question ends with type and class (two bytes each)
+    if (end - payload < 2 * INT16SZ) {
+        ERROR(ctx, "DNS reply too short\n");
+        return -1;
+    }
+
+    int type;
+    GETSHORT(type, payload);
+    if (type != T_TXT) {
+        ERROR(ctx, "DNS reply of unexpected type: %d\n", type);
+        return -1;
+    }
+
+    // Skip class
+    payload += INT16SZ;
+
+    // Walk through CNAMEs
+    unsigned int size = 0;
+    do {
+        // Skip the data of the previous record
+        payload += size;
+
+        if ((len = dn_expand(answer, end, payload, host, sizeof(host))) < 0) {
+            ERROR(ctx, "dn_expand() failed\n");
+            return -1;
+        }
+
+        payload += len;
+
+        // type (2), class (2), ttl (4) and data length (2) have to fit
+        if (end - payload < 10) {
+            ERROR(ctx, "DNS reply too short\n");
+            return -1;
+        }
+
+        // Read type, skip class and ttl
+        GETSHORT(type, payload);
+        payload += INT16SZ + INT32SZ;
+
+        // Read size and make sure that the data does not exceed the reply
+        GETSHORT(size, payload);
+        if (size > (size_t)(end - payload)) {
+            ERROR(ctx, "DNS RR overflow\n");
+            return -1;
+        }
+    } while (type == T_CNAME);
+
+    if (type != T_TXT) {
+        ERROR(ctx, "Not a TXT record\n");
+        return -1;
+    }
+
+    // The first byte of the data is the length of the string that follows
+    len = size ? *payload : 0;
+    if (!len || (unsigned int)len >= size) {
+        ERROR(ctx, "Broken TXT record (len = %d, size = %d)\n", len, (int)size);
+        return -1;
+    }
+
+    // Copy the string into a buffer of its own. Terminating it in place would
+    // write one byte past the end of answer if the string ends the packet.
+    // The length is a single byte, so 256 bytes always suffice.
+    unsigned char txt[256];
+    memcpy(txt, payload + 1, len);
+    txt[len] = '\0';
+
+    DEBUG(ctx, "Resolved to: %s\n", txt);
+
+    // Parse timestamp
+    r = parse_timestamp(txt, t);
+
+    return r;
+}