]> git.ipfire.org Git - location/libloc.git/blobdiff - src/python/location-downloader.in
python: Correctly set log level for root logger
[location/libloc.git] / src / python / location-downloader.in
index 9b19de479589577e5113a30b2b8d9ff98a3794c0..f1d38f1418a87c803d78ee18beea13932aef5fae 100644 (file)
@@ -18,7 +18,8 @@
 ###############################################################################
 
 import argparse
-import gettext
+import datetime
+import logging
 import lzma
 import os
 import random
@@ -32,9 +33,7 @@ import urllib.request
 
 # Load our location module
 import location
-
-import logging
-logging.basicConfig(level=logging.INFO)
+from location.i18n import _
 
 DATABASE_FILENAME = "test.db.xz"
 MIRRORS = (
@@ -42,19 +41,9 @@ MIRRORS = (
        "https://people.ipfire.org/~ms/location/",
 )
 
-# i18n
-def _(singular, plural=None, n=None):
-       if plural:
-               return gettext.dngettext("libloc", singular, plural, n)
-
-       return gettext.dgettext("libloc", singular)
-
-class NotModifiedError(Exception):
-       """
-               Raised when the file has not been modified on the server
-       """
-       pass
-
+# Initialise logging
+log = logging.getLogger("location.downloader")
+log.propagate = 1
 
 class Downloader(object):
        def __init__(self, mirrors):
@@ -85,7 +74,7 @@ class Downloader(object):
 
                # Update headers
                headers.update({
-                       "User-Agent" : "location-downloader/%s" % location.__version__,
+                       "User-Agent" : "location-downloader/@VERSION@",
                })
 
                # Set headers
@@ -100,43 +89,39 @@ class Downloader(object):
 
        def _send_request(self, req, **kwargs):
                # Log request headers
-               logging.debug("HTTP %s Request to %s" % (req.method, req.host))
-               logging.debug(" URL: %s" % req.full_url)
-               logging.debug(" Headers:")
+               log.debug("HTTP %s Request to %s" % (req.method, req.host))
+               log.debug("     URL: %s" % req.full_url)
+               log.debug("     Headers:")
                for k, v in req.header_items():
-                       logging.debug("         %s: %s" % (k, v))
+                       log.debug("             %s: %s" % (k, v))
 
                try:
                        res = urllib.request.urlopen(req, **kwargs)
 
                except urllib.error.HTTPError as e:
                        # Log response headers
-                       logging.debug("HTTP Response: %s" % e.code)
-                       logging.debug(" Headers:")
+                       log.debug("HTTP Response: %s" % e.code)
+                       log.debug("     Headers:")
                        for header in e.headers:
-                               logging.debug("         %s: %s" % (header, e.headers[header]))
-
-                       # Handle 304
-                       if e.code == 304:
-                               raise NotModifiedError() from e
+                               log.debug("             %s: %s" % (header, e.headers[header]))
 
                        # Raise all other errors
                        raise e
 
                # Log response headers
-               logging.debug("HTTP Response: %s" % res.code)
-               logging.debug(" Headers:")
+               log.debug("HTTP Response: %s" % res.code)
+               log.debug("     Headers:")
                for k, v in res.getheaders():
-                       logging.debug("         %s: %s" % (k, v))
+                       log.debug("             %s: %s" % (k, v))
 
                return res
 
-       def download(self, url, mtime=None, **kwargs):
+       def download(self, url, public_key, timestamp=None, **kwargs):
                headers = {}
 
-               if mtime:
-                       headers["If-Modified-Since"] = time.strftime(
-                               "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+               if timestamp:
+                       headers["If-Modified-Since"] = timestamp.strftime(
+                               "%a, %d %b %Y %H:%M:%S GMT",
                        )
 
                t = tempfile.NamedTemporaryFile(delete=False)
@@ -161,37 +146,65 @@ class Downloader(object):
                                                        if buf:
                                                                t.write(buf)
 
-                                               # Write all data to disk
-                                               t.flush()
-
-                               # Nothing to do when the database on the server is up to date
-                               except NotModifiedError:
-                                       logging.info("Local database is up to date")
-                                       return
+                                       # Write all data to disk
+                                       t.flush()
 
                                # Catch decompression errors
                                except lzma.LZMAError as e:
-                                       logging.warning("Could not decompress downloaded file: %s" % e)
+                                       log.warning("Could not decompress downloaded file: %s" % e)
                                        continue
 
-                               # XXX what do we catch here?
                                except urllib.error.HTTPError as e:
-                                       if e.code == 404:
-                                               continue
+                                       # The file on the server was too old
+                                       if e.code == 304:
+                                               log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)
 
-                                       # Truncate the target file and drop downloaded content
-                                       try:
-                                               t.truncate()
-                                       except OSError:
-                                               pass
+                                       # Log any other HTTP errors
+                                       else:
+                                               log.warning("%s reported: %s" % (mirror, e))
 
-                                       raise e
+                                       # Throw away any downloaded content and try again
+                                       t.truncate()
 
-                               # Return temporary file
-                               return t
+                               else:
+                                       # Check if the downloaded database is recent
+                                       if not self._check_database(t, public_key, timestamp):
+                                               log.warning("Downloaded database is outdated. Trying next mirror...")
+
+                                               # Throw away the data and try again
+                                               t.truncate()
+                                               continue
+
+                                       # Return temporary file
+                                       return t
 
                raise FileNotFoundError(url)
 
+       def _check_database(self, f, public_key, timestamp=None):
+               """
+                       Checks the downloaded database if it can be opened,
+                       verified and if it is recent enough
+               """
+               log.debug("Opening downloaded database at %s" % f.name)
+
+               db = location.Database(f.name)
+
+               # Database is not recent
+               if timestamp and db.created_at < timestamp.timestamp():
+                       return False
+
+               log.info("Downloaded new database from %s" % (time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               )))
+
+               # Verify the database
+               with open(public_key, "r") as f:
+                       if not db.verify(f):
+                               log.error("Could not verify database")
+                               return False
+
+               return True
+
 
 class CLI(object):
        def __init__(self):
@@ -209,22 +222,32 @@ class CLI(object):
 
                # version
                parser.add_argument("--version", action="version",
-                       version="%%(prog)s %s" % location.__version__)
+                       version="%(prog)s @VERSION@")
 
                # database
                parser.add_argument("--database", "-d",
                        default="@databasedir@/database.db", help=_("Path to database"),
                )
 
+               # public key
+               parser.add_argument("--public-key", "-k",
+                       default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
+               )
+
                # Update
                update = subparsers.add_parser("update", help=_("Update database"))
                update.set_defaults(func=self.handle_update)
 
+               # Verify
+               verify = subparsers.add_parser("verify",
+                       help=_("Verify the downloaded database"))
+               verify.set_defaults(func=self.handle_verify)
+
                args = parser.parse_args()
 
                # Enable debug logging
                if args.debug:
-                       logging.basicConfig(level=logging.DEBUG)
+                       location.logger.set_level(logging.DEBUG)
 
                # Print usage if no action was given
                if not "func" in args:
@@ -248,47 +271,40 @@ class CLI(object):
                sys.exit(0)
 
        def handle_update(self, ns):
-               mtime = None
+               # Fetch the version we need from DNS
+               t = location.discover_latest_version()
+
+               # Parse timestamp into datetime format
+               try:
+                       timestamp = datetime.datetime.fromtimestamp(t)
+               except:
+                       raise
 
                # Open database
                try:
                        db = location.Database(ns.database)
 
-                       # Get mtime of the old file
-                       mtime = os.path.getmtime(ns.database)
+                       # Check if we are already on the latest version
+                       if db.created_at >= timestamp.timestamp():
+                               log.info("Already on the latest version")
+                               return
+
                except FileNotFoundError as e:
                        db = None
 
                # Try downloading a new database
                try:
-                       t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+                       t = self.downloader.download(DATABASE_FILENAME,
+                               public_key=ns.public_key, timestamp=timestamp)
 
                # If no file could be downloaded, log a message
                except FileNotFoundError as e:
-                       logging.error("Could not download a new database")
+                       log.error("Could not download a new database")
                        return 1
 
                # If we have not received a new file, there is nothing to do
                if not t:
-                       return 0
-
-               # Save old database creation time
-               created_at = db.created_at if db else 0
-
-               # Try opening the downloaded file
-               try:
-                       db = location.Database(t.name)
-               except Exception as e:
-                       raise e
-
-               # Check if the downloaded file is newer
-               if db.created_at <= created_at:
-                       logging.warning("Downloaded database is older than the current version")
-                       return 1
-
-               logging.info("Downloaded new database from %s" % (time.strftime(
-                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
-               )))
+                       return 3
 
                # Write temporary file to destination
                shutil.copyfile(t.name, ns.database)
@@ -296,6 +312,25 @@ class CLI(object):
                # Remove temporary file
                os.unlink(t.name)
 
+               return 0
+
+       def handle_verify(self, ns):
+               try:
+                       db = location.Database(ns.database)
+               except FileNotFoundError as e:
+                       log.error("%s: %s" % (ns.database, e))
+                       return 127
+
+               # Verify the database
+               with open(ns.public_key, "r") as f:
+                       if not db.verify(f):
+                               log.error("Could not verify database")
+                               return 1
+
+               # Success
+               log.debug("Database successfully verified")
+               return 0
+
 
 def main():
        # Run the command line interface