--- /dev/null
+#!/usr/bin/python3
+###############################################################################
+# #
+# libloc - A library to determine the location of someone on the Internet #
+# #
+# Copyright (C) 2019 IPFire Development Team <info@ipfire.org> #
+# #
+# This library is free software; you can redistribute it and/or #
+# modify it under the terms of the GNU Lesser General Public #
+# License as published by the Free Software Foundation; either #
+# version 2.1 of the License, or (at your option) any later version. #
+# #
+# This library is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
+# Lesser General Public License for more details. #
+# #
+###############################################################################
+
+import argparse
+import gettext
+import lzma
+import os
+import random
+import shutil
+import sys
+import tempfile
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
+
+# Load our location module
+import location
+
+import logging
+logging.basicConfig(level=logging.INFO)
+
+DATABASE_FILENAME = "test.db.xz"
+MIRRORS = (
+ "https://location.ipfire.org/databases/",
+ "https://people.ipfire.org/~ms/location/",
+)
+
+# i18n
+def _(singular, plural=None, n=None):
+ if plural:
+ return gettext.dngettext("libloc", singular, plural, n)
+
+ return gettext.dgettext("libloc", singular)
+
+class NotModifiedError(Exception):
+ """
+ Raised when the file has not been modified on the server
+ """
+ pass
+
+
+class Downloader(object):
+ def __init__(self, mirrors):
+ self.mirrors = list(mirrors)
+
+ # Randomize mirrors
+ random.shuffle(self.mirrors)
+
+ # Get proxies from environment
+ self.proxies = self._get_proxies()
+
+ def _get_proxies(self):
+ proxies = {}
+
+ for protocol in ("https", "http"):
+ proxy = os.environ.get("%s_proxy" % protocol, None)
+
+ if proxy:
+ proxies[protocol] = proxy
+
+ return proxies
+
+ def _make_request(self, url, baseurl=None, headers={}):
+ if baseurl:
+ url = urllib.parse.urljoin(baseurl, url)
+
+ req = urllib.request.Request(url, method="GET")
+
+ # Update headers
+ headers.update({
+ "User-Agent" : "location-downloader/%s" % location.__version__,
+ })
+
+ # Set headers
+ for header in headers:
+ req.add_header(header, headers[header])
+
+ # Set proxies
+ for protocol in self.proxies:
+ req.set_proxy(self.proxies[protocol], protocol)
+
+ return req
+
+ def _send_request(self, req, **kwargs):
+ # Log request headers
+ logging.debug("HTTP %s Request to %s" % (req.method, req.host))
+ logging.debug(" URL: %s" % req.full_url)
+ logging.debug(" Headers:")
+ for k, v in req.header_items():
+ logging.debug(" %s: %s" % (k, v))
+
+ try:
+ res = urllib.request.urlopen(req, **kwargs)
+
+ except urllib.error.HTTPError as e:
+ # Log response headers
+ logging.debug("HTTP Response: %s" % e.code)
+ logging.debug(" Headers:")
+ for header in e.headers:
+ logging.debug(" %s: %s" % (header, e.headers[header]))
+
+ # Handle 304
+ if e.code == 304:
+ raise NotModifiedError() from e
+
+ # Raise all other errors
+ raise e
+
+ # Log response headers
+ logging.debug("HTTP Response: %s" % res.code)
+ logging.debug(" Headers:")
+ for k, v in res.getheaders():
+ logging.debug(" %s: %s" % (k, v))
+
+ return res
+
+ def download(self, url, mtime=None, **kwargs):
+ headers = {}
+
+ if mtime:
+ headers["If-Modified-Since"] = time.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+ )
+
+ t = tempfile.NamedTemporaryFile(delete=False)
+ with t:
+ # Try all mirrors
+ for mirror in self.mirrors:
+ # Prepare HTTP request
+ req = self._make_request(url, baseurl=mirror, headers=headers)
+
+ try:
+ with self._send_request(req) as res:
+ decompressor = lzma.LZMADecompressor()
+
+ # Read all data
+ while True:
+ buf = res.read(1024)
+ if not buf:
+ break
+
+ # Decompress data
+ buf = decompressor.decompress(buf)
+ if buf:
+ t.write(buf)
+
+ # Write all data to disk
+ t.flush()
+
+ # Nothing to do when the database on the server is up to date
+ except NotModifiedError:
+ logging.info("Local database is up to date")
+ return
+
+ # Catch decompression errors
+ except lzma.LZMAError as e:
+ logging.warning("Could not decompress downloaded file: %s" % e)
+ continue
+
+ # XXX what do we catch here?
+ except urllib.error.HTTPError as e:
+ if e.code == 404:
+ continue
+
+ # Truncate the target file and drop downloaded content
+ try:
+ t.truncate()
+ except OSError:
+ pass
+
+ raise e
+
+ # Return temporary file
+ return t
+
+ raise FileNotFoundError(url)
+
+
+class CLI(object):
+ def __init__(self):
+ self.downloader = Downloader(mirrors=MIRRORS)
+
+ def parse_cli(self):
+ parser = argparse.ArgumentParser(
+ description=_("Location Downloader Command Line Interface"),
+ )
+ subparsers = parser.add_subparsers()
+
+ # Global configuration flags
+ parser.add_argument("--debug", action="store_true",
+ help=_("Enable debug output"))
+
+ # version
+ parser.add_argument("--version", action="version",
+ version="%%(prog)s %s" % location.__version__)
+
+ # database
+ parser.add_argument("--database", "-d",
+ default="@databasedir@/database.db", help=_("Path to database"),
+ )
+
+ # Update
+ update = subparsers.add_parser("update", help=_("Update database"))
+ update.set_defaults(func=self.handle_update)
+
+ args = parser.parse_args()
+
+ # Enable debug logging
+ if args.debug:
+ logging.basicConfig(level=logging.DEBUG)
+
+ # Print usage if no action was given
+ if not "func" in args:
+ parser.print_usage()
+ sys.exit(2)
+
+ return args
+
+ def run(self):
+ # Parse command line arguments
+ args = self.parse_cli()
+
+ # Call function
+ ret = args.func(args)
+
+ # Return with exit code
+ if ret:
+ sys.exit(ret)
+
+ # Otherwise just exit
+ sys.exit(0)
+
+ def handle_update(self, ns):
+ mtime = None
+
+ # Open database
+ try:
+ db = location.Database(ns.database)
+
+ # Get mtime of the old file
+ mtime = os.path.getmtime(ns.database)
+ except FileNotFoundError as e:
+ db = None
+
+ # Try downloading a new database
+ try:
+ t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+
+ # If no file could be downloaded, log a message
+ except FileNotFoundError as e:
+ logging.error("Could not download a new database")
+ return 1
+
+ # If we have not received a new file, there is nothing to do
+ if not t:
+ return 0
+
+ # Save old database creation time
+ created_at = db.created_at if db else 0
+
+ # Try opening the downloaded file
+ try:
+ db = location.Database(t.name)
+ except Exception as e:
+ raise e
+
+ # Check if the downloaded file is newer
+ if db.created_at <= created_at:
+ logging.warning("Downloaded database is older than the current version")
+ return 1
+
+ logging.info("Downloaded new database from %s" % (time.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+ )))
+
+ # Write temporary file to destination
+ shutil.copyfile(t.name, ns.database)
+
+ # Remove temporary file
+ os.unlink(t.name)
+
+
+def main():
+ # Run the command line interface
+ c = CLI()
+ c.run()
+
+main()