]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/python/location-downloader.in
downloader: Check DNS for most recent version
[people/ms/libloc.git] / src / python / location-downloader.in
index 961c5dffdaf631344588ebf04df8ebf1fcb51c9c..4fdf4042a296e761546e7df8538bed980ea6549b 100644 (file)
@@ -18,6 +18,7 @@
 ###############################################################################
 
 import argparse
+import datetime
 import gettext
 import logging
 import logging.handlers
@@ -72,12 +73,6 @@ def _(singular, plural=None, n=None):
 
        return gettext.dgettext("libloc", singular)
 
-class NotModifiedError(Exception):
-       """
-               Raised when the file has not been modified on the server
-       """
-       pass
-
 
 class Downloader(object):
        def __init__(self, mirrors):
@@ -139,10 +134,6 @@ class Downloader(object):
                        for header in e.headers:
                                log.debug("             %s: %s" % (header, e.headers[header]))
 
-                       # Handle 304
-                       if e.code == 304:
-                               raise NotModifiedError() from e
-
                        # Raise all other errors
                        raise e
 
@@ -154,12 +145,12 @@ class Downloader(object):
 
                return res
 
-       def download(self, url, mtime=None, **kwargs):
+       def download(self, url, timestamp=None, **kwargs):
                headers = {}
 
-               if mtime:
-                       headers["If-Modified-Since"] = time.strftime(
-                               "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+               if timestamp:
+                       headers["If-Modified-Since"] = timestamp.strftime(
+                               "%a, %d %b %Y %H:%M:%S GMT",
                        )
 
                t = tempfile.NamedTemporaryFile(delete=False)
@@ -184,37 +175,59 @@ class Downloader(object):
                                                        if buf:
                                                                t.write(buf)
 
-                                               # Write all data to disk
-                                               t.flush()
-
-                               # Nothing to do when the database on the server is up to date
-                               except NotModifiedError:
-                                       log.info("Local database is up to date")
-                                       return
+                                       # Write all data to disk
+                                       t.flush()
 
                                # Catch decompression errors
                                except lzma.LZMAError as e:
                                        log.warning("Could not decompress downloaded file: %s" % e)
                                        continue
 
-                               # XXX what do we catch here?
                                except urllib.error.HTTPError as e:
-                                       if e.code == 404:
-                                               continue
+                                       # The file on the server was too old
+                                       if e.code == 304:
+                                               log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)
 
-                                       # Truncate the target file and drop downloaded content
-                                       try:
-                                               t.truncate()
-                                       except OSError:
-                                               pass
+                                       # Log any other HTTP errors
+                                       else:
+                                               log.warning("%s reported: %s" % (mirror, e))
+
+                                       # Throw away any downloaded content and try again
+                                       t.truncate()
 
-                                       raise e
+                               else:
+                                       # Check if the downloaded database is recent
+                                       if not self._check_database(t, timestamp):
+                                               log.warning("Downloaded database is outdated. Trying next mirror...")
 
-                               # Return temporary file
-                               return t
+                                               # Throw away the data and try again
+                                               t.truncate()
+                                               continue
+
+                                       # Return temporary file
+                                       return t
 
                raise FileNotFoundError(url)
 
+       def _check_database(self, f, timestamp=None):
+               """
+                       Checks the downloaded database if it can be opened,
+                       verified and if it is recent enough
+               """
+               log.debug("Opening downloaded database at %s" % f.name)
+
+               db = location.Database(f.name)
+
+               # Database is not recent
+               if timestamp and db.created_at < timestamp.timestamp():
+                       return False
+
+               log.info("Downloaded new database from %s" % (time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               )))
+
+               return True
+
 
 class CLI(object):
        def __init__(self):
@@ -271,20 +284,30 @@ class CLI(object):
                sys.exit(0)
 
        def handle_update(self, ns):
-               mtime = None
+               # Fetch the version we need from DNS
+               t = location.discover_latest_version()
+
+               # Parse timestamp into datetime format
+               try:
+                       timestamp = datetime.datetime.fromtimestamp(t)
+               except:
+                       raise
 
                # Open database
                try:
                        db = location.Database(ns.database)
 
-                       # Get mtime of the old file
-                       mtime = os.path.getmtime(ns.database)
+                       # Check if we are already on the latest version
+                       if db.created_at >= timestamp.timestamp():
+                               log.info("Already on the latest version")
+                               return
+
                except FileNotFoundError as e:
                        db = None
 
                # Try downloading a new database
                try:
-                       t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+                       t = self.downloader.download(DATABASE_FILENAME, timestamp=timestamp)
 
                # If no file could be downloaded, log a message
                except FileNotFoundError as e:
@@ -295,24 +318,6 @@ class CLI(object):
                if not t:
                        return 0
 
-               # Save old database creation time
-               created_at = db.created_at if db else 0
-
-               # Try opening the downloaded file
-               try:
-                       db = location.Database(t.name)
-               except Exception as e:
-                       raise e
-
-               # Check if the downloaded file is newer
-               if db.created_at <= created_at:
-                       log.warning("Downloaded database is older than the current version")
-                       return 1
-
-               log.info("Downloaded new database from %s" % (time.strftime(
-                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
-               )))
-
                # Write temporary file to destination
                shutil.copyfile(t.name, ns.database)