git.ipfire.org Git - people/ms/libloc.git/commitdiff
Add download script to automatically update the database
author Michael Tremer <michael.tremer@ipfire.org>
Sun, 17 Nov 2019 13:45:39 +0000 (13:45 +0000)
committer Michael Tremer <michael.tremer@ipfire.org>
Sun, 17 Nov 2019 13:46:03 +0000 (13:46 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
.gitignore
Makefile.am
src/python/location-downloader.in [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
index a2dc6b14335f63fab2babdcdad0306171786bab4..19e05bdf62f12d7cc48e4828e6fddc67db157560 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ Makefile.in
 /configure
 /libtool
 /stamp-h1
+/src/python/location-downloader
 /src/python/location-query
 /test.db
 /testdata.db
diff --git a/Makefile.am b/Makefile.am
index 162956a6e33be015c66db4f8a566ac616cdda7b6..33a7e9804777d2bcf02d7ef1444882367096861a 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -210,12 +210,15 @@ uninstall-perl:
                $(DESTDIR)/$(prefix)/man/man3/Location.3pm
 
 bin_SCRIPTS = \
+       src/python/location-downloader \
        src/python/location-query
 
 EXTRA_DIST += \
+       src/python/location-downloader.in \
        src/python/location-query.in
 
 CLEANFILES += \
+       src/python/location-downloader \
        src/python/location-query
 
 # ------------------------------------------------------------------------------
diff --git a/src/python/location-downloader.in b/src/python/location-downloader.in
new file mode 100644
index 0000000..9b19de4
--- /dev/null
+++ b/src/python/location-downloader.in
@@ -0,0 +1,305 @@
+#!/usr/bin/python3
+###############################################################################
+#                                                                             #
+# libloc - A library to determine the location of someone on the Internet     #
+#                                                                             #
+# Copyright (C) 2019 IPFire Development Team <info@ipfire.org>                #
+#                                                                             #
+# This library is free software; you can redistribute it and/or               #
+# modify it under the terms of the GNU Lesser General Public                  #
+# License as published by the Free Software Foundation; either                #
+# version 2.1 of the License, or (at your option) any later version.          #
+#                                                                             #
+# This library is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU           #
+# Lesser General Public License for more details.                             #
+#                                                                             #
+###############################################################################
+
+import argparse
+import gettext
+import lzma
+import os
+import random
+import shutil
+import sys
+import tempfile
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
+
+# Load our location module
+import location
+
+import logging
+logging.basicConfig(level=logging.INFO)
+
+DATABASE_FILENAME = "test.db.xz"
+MIRRORS = (
+       "https://location.ipfire.org/databases/",
+       "https://people.ipfire.org/~ms/location/",
+)
+
+# i18n
+def _(singular, plural=None, n=None):
+       if plural:
+               return gettext.dngettext("libloc", singular, plural, n)
+
+       return gettext.dgettext("libloc", singular)
+
+class NotModifiedError(Exception):
+       """
+               Raised when the file has not been modified on the server
+       """
+       pass
+
+
+class Downloader(object):
+       def __init__(self, mirrors):
+               self.mirrors = list(mirrors)
+
+               # Randomize mirrors
+               random.shuffle(self.mirrors)
+
+               # Get proxies from environment
+               self.proxies = self._get_proxies()
+
+       def _get_proxies(self):
+               proxies = {}
+
+               for protocol in ("https", "http"):
+                       proxy = os.environ.get("%s_proxy" % protocol, None)
+
+                       if proxy:
+                               proxies[protocol] = proxy
+
+               return proxies
+
+       def _make_request(self, url, baseurl=None, headers={}):
+               if baseurl:
+                       url = urllib.parse.urljoin(baseurl, url)
+
+               req = urllib.request.Request(url, method="GET")
+
+               # Update headers
+               headers.update({
+                       "User-Agent" : "location-downloader/%s" % location.__version__,
+               })
+
+               # Set headers
+               for header in headers:
+                       req.add_header(header, headers[header])
+
+               # Set proxies
+               for protocol in self.proxies:
+                       req.set_proxy(self.proxies[protocol], protocol)
+
+               return req
+
+       def _send_request(self, req, **kwargs):
+               # Log request headers
+               logging.debug("HTTP %s Request to %s" % (req.method, req.host))
+               logging.debug(" URL: %s" % req.full_url)
+               logging.debug(" Headers:")
+               for k, v in req.header_items():
+                       logging.debug("         %s: %s" % (k, v))
+
+               try:
+                       res = urllib.request.urlopen(req, **kwargs)
+
+               except urllib.error.HTTPError as e:
+                       # Log response headers
+                       logging.debug("HTTP Response: %s" % e.code)
+                       logging.debug(" Headers:")
+                       for header in e.headers:
+                               logging.debug("         %s: %s" % (header, e.headers[header]))
+
+                       # Handle 304
+                       if e.code == 304:
+                               raise NotModifiedError() from e
+
+                       # Raise all other errors
+                       raise e
+
+               # Log response headers
+               logging.debug("HTTP Response: %s" % res.code)
+               logging.debug(" Headers:")
+               for k, v in res.getheaders():
+                       logging.debug("         %s: %s" % (k, v))
+
+               return res
+
+       def download(self, url, mtime=None, **kwargs):
+               headers = {}
+
+               if mtime:
+                       headers["If-Modified-Since"] = time.strftime(
+                               "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+                       )
+
+               t = tempfile.NamedTemporaryFile(delete=False)
+               with t:
+                       # Try all mirrors
+                       for mirror in self.mirrors:
+                               # Prepare HTTP request
+                               req = self._make_request(url, baseurl=mirror, headers=headers)
+
+                               try:
+                                       with self._send_request(req) as res:
+                                               decompressor = lzma.LZMADecompressor()
+
+                                               # Read all data
+                                               while True:
+                                                       buf = res.read(1024)
+                                                       if not buf:
+                                                               break
+
+                                                       # Decompress data
+                                                       buf = decompressor.decompress(buf)
+                                                       if buf:
+                                                               t.write(buf)
+
+                                               # Write all data to disk
+                                               t.flush()
+
+                               # Nothing to do when the database on the server is up to date
+                               except NotModifiedError:
+                                       logging.info("Local database is up to date")
+                                       return
+
+                               # Catch decompression errors
+                               except lzma.LZMAError as e:
+                                       logging.warning("Could not decompress downloaded file: %s" % e)
+                                       continue
+
+                               # XXX what do we catch here?
+                               except urllib.error.HTTPError as e:
+                                       if e.code == 404:
+                                               continue
+
+                                       # Truncate the target file and drop downloaded content
+                                       try:
+                                               t.truncate()
+                                       except OSError:
+                                               pass
+
+                                       raise e
+
+                               # Return temporary file
+                               return t
+
+               raise FileNotFoundError(url)
+
+
+class CLI(object):
+       def __init__(self):
+               self.downloader = Downloader(mirrors=MIRRORS)
+
+       def parse_cli(self):
+               parser = argparse.ArgumentParser(
+                       description=_("Location Downloader Command Line Interface"),
+               )
+               subparsers = parser.add_subparsers()
+
+               # Global configuration flags
+               parser.add_argument("--debug", action="store_true",
+                       help=_("Enable debug output"))
+
+               # version
+               parser.add_argument("--version", action="version",
+                       version="%%(prog)s %s" % location.__version__)
+
+               # database
+               parser.add_argument("--database", "-d",
+                       default="@databasedir@/database.db", help=_("Path to database"),
+               )
+
+               # Update
+               update = subparsers.add_parser("update", help=_("Update database"))
+               update.set_defaults(func=self.handle_update)
+
+               args = parser.parse_args()
+
+               # Enable debug logging
+               if args.debug:
+                       logging.basicConfig(level=logging.DEBUG)
+
+               # Print usage if no action was given
+               if not "func" in args:
+                       parser.print_usage()
+                       sys.exit(2)
+
+               return args
+
+       def run(self):
+               # Parse command line arguments
+               args = self.parse_cli()
+
+               # Call function
+               ret = args.func(args)
+
+               # Return with exit code
+               if ret:
+                       sys.exit(ret)
+
+               # Otherwise just exit
+               sys.exit(0)
+
+       def handle_update(self, ns):
+               mtime = None
+
+               # Open database
+               try:
+                       db = location.Database(ns.database)
+
+                       # Get mtime of the old file
+                       mtime = os.path.getmtime(ns.database)
+               except FileNotFoundError as e:
+                       db = None
+
+               # Try downloading a new database
+               try:
+                       t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+
+               # If no file could be downloaded, log a message
+               except FileNotFoundError as e:
+                       logging.error("Could not download a new database")
+                       return 1
+
+               # If we have not received a new file, there is nothing to do
+               if not t:
+                       return 0
+
+               # Save old database creation time
+               created_at = db.created_at if db else 0
+
+               # Try opening the downloaded file
+               try:
+                       db = location.Database(t.name)
+               except Exception as e:
+                       raise e
+
+               # Check if the downloaded file is newer
+               if db.created_at <= created_at:
+                       logging.warning("Downloaded database is older than the current version")
+                       return 1
+
+               logging.info("Downloaded new database from %s" % (time.strftime(
+                       "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+               )))
+
+               # Write temporary file to destination
+               shutil.copyfile(t.name, ns.database)
+
+               # Remove temporary file
+               os.unlink(t.name)
+
+
+def main():
+       # Run the command line interface
+       c = CLI()
+       c.run()
+
+main()
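
Once installed through the bin_SCRIPTS rule added in Makefile.am above, the script is available as location-downloader, with the @databasedir@ placeholder substituted at build time. A minimal sketch of driving it from Python follows; the subprocess wrapper and the database path are illustrations only and are not part of this commit:

import subprocess

# Refresh the location database using the new downloader script.
# The --debug and --database flags and the "update" subcommand come
# from the argparse setup above; /tmp/location.db is just an example path.
subprocess.run(
    ["location-downloader", "--debug", "--database", "/tmp/location.db", "update"],
    check=True,
)

With check=True, a non-zero exit status (for example when no mirror served the file, or when the downloaded database is older than the local one) raises subprocess.CalledProcessError.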