#!/usr/bin/python3
###############################################################################
#                                                                             #
# libloc - A library to determine the location of someone on the Internet    #
#                                                                             #
# Copyright (C) 2019 IPFire Development Team                                  #
#                                                                             #
# This library is free software; you can redistribute it and/or              #
# modify it under the terms of the GNU Lesser General Public                  #
# License as published by the Free Software Foundation; either               #
# version 2.1 of the License, or (at your option) any later version.         #
#                                                                             #
# This library is distributed in the hope that it will be useful,            #
# but WITHOUT ANY WARRANTY; without even the implied warranty of             #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU           #
# Lesser General Public License for more details.                            #
#                                                                             #
###############################################################################

import argparse
import gettext
import logging
import logging.handlers
import lzma
import os
import random
import shutil
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request

# Load our location module
import location

DATABASE_FILENAME = "test.db.xz"

MIRRORS = (
    "https://location.ipfire.org/databases/",
    "https://people.ipfire.org/~ms/location/",
)


def setup_logging(level=logging.INFO):
    l = logging.getLogger("location-downloader")
    l.setLevel(level)

    # Log to console
    h = logging.StreamHandler()
    h.setLevel(logging.DEBUG)
    l.addHandler(h)

    # Log to syslog
    h = logging.handlers.SysLogHandler(address="/dev/log",
        facility=logging.handlers.SysLogHandler.LOG_DAEMON)
    h.setLevel(logging.INFO)
    l.addHandler(h)

    # Format syslog messages
    formatter = logging.Formatter("location-downloader[%(process)d]: %(message)s")
    h.setFormatter(formatter)

    return l

# Initialise logging
log = setup_logging()

# i18n
def _(singular, plural=None, n=None):
    if plural:
        return gettext.dngettext("libloc", singular, plural, n)

    return gettext.dgettext("libloc", singular)


class NotModifiedError(Exception):
    """
        Raised when the file has not been modified on the server
    """
    pass


class Downloader(object):
    def __init__(self, mirrors):
        self.mirrors = list(mirrors)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol, None)

            if proxy:
                proxies[protocol] = proxy

        return proxies

    def _make_request(self, url, baseurl=None, headers={}):
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Update headers
        headers.update({
            "User-Agent" : "location-downloader/%s" % location.__version__,
        })

        # Set headers
        for header in headers:
            req.add_header(header, headers[header])

        # Set proxies
        for protocol in self.proxies:
            req.set_proxy(self.proxies[protocol], protocol)

        return req

    def _send_request(self, req, **kwargs):
        # Log request headers
        log.debug("HTTP %s Request to %s" % (req.method, req.host))
        log.debug(" URL: %s" % req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s" % (k, v))

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s" % e.code)
            log.debug(" Headers:")
            for header in e.headers:
                log.debug(" %s: %s" % (header, e.headers[header]))

            # Handle 304
            if e.code == 304:
                raise NotModifiedError() from e

            # Raise all other errors
            raise e

        # Log response headers
        log.debug("HTTP Response: %s" % res.code)
        log.debug(" Headers:")
        for k, v in res.getheaders():
            log.debug(" %s: %s" % (k, v))

        return res

    def download(self, url, mtime=None, **kwargs):
        headers = {}

        if mtime:
            headers["If-Modified-Since"] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
            )

        t = tempfile.NamedTemporaryFile(delete=False)

        with t:
            # Try all mirrors
            for mirror in self.mirrors:
                # Prepare HTTP request
                req = self._make_request(url, baseurl=mirror, headers=headers)

                try:
                    with self._send_request(req) as res:
                        decompressor = lzma.LZMADecompressor()

                        # Read all data
                        while True:
                            buf = res.read(1024)
                            if not buf:
                                break

                            # Decompress data
                            buf = decompressor.decompress(buf)
                            if buf:
                                t.write(buf)

                    # Write all data to disk
                    t.flush()

                # Nothing to do when the database on the server is up to date
                except NotModifiedError:
                    log.info("Local database is up to date")
                    return

                # Catch decompression errors
                except lzma.LZMAError as e:
                    log.warning("Could not decompress downloaded file: %s" % e)
                    continue

                # XXX what do we catch here?
                except urllib.error.HTTPError as e:
                    if e.code == 404:
                        continue

                    # Truncate the target file and drop downloaded content
                    try:
                        # Seek back to the beginning first, otherwise truncate()
                        # keeps everything written before the current position
                        t.seek(0)
                        t.truncate()
                    except OSError:
                        pass

                    raise e

                # Return temporary file
                return t

        raise FileNotFoundError(url)


class CLI(object):
    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            log.setLevel(logging.DEBUG)

        # Print usage if no action was given
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        # Parse command line arguments
        args = self.parse_cli()

        # Call function
        ret = args.func(args)

        # Return with exit code
        if ret:
            sys.exit(ret)

        # Otherwise just exit
        sys.exit(0)

    def handle_update(self, ns):
        mtime = None

        # Open database
        try:
            db = location.Database(ns.database)

            # Get mtime of the old file
            mtime = os.path.getmtime(ns.database)
        except FileNotFoundError as e:
            db = None

        # Try downloading a new database
        try:
            t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)

        # If no file could be downloaded, log a message
        except FileNotFoundError as e:
            log.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if not t:
            return 0

        # Save old database creation time
        created_at = db.created_at if db else 0

        # Try opening the downloaded file
        try:
            db = location.Database(t.name)
        except Exception as e:
            raise e

        # Check if the downloaded file is newer
        if db.created_at <= created_at:
            log.warning("Downloaded database is older than the current version")
            return 1

        log.info("Downloaded new database from %s" % (time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
        )))

        # Write temporary file to destination
        shutil.copyfile(t.name, ns.database)

        # Remove temporary file
        os.unlink(t.name)


def main():
    # Run the command line interface
    c = CLI()
    c.run()


if __name__ == "__main__":
    main()
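
# Example invocation (a sketch, not part of the script itself): the command
# name below assumes the script is installed as "location-downloader", the
# database path is a hypothetical example, and "@databasedir@" above is
# presumably substituted with the real database directory at build time.
#
#   location-downloader update
#   location-downloader --debug --database /var/lib/location/database.db update
#
# The "update" command fetches DATABASE_FILENAME from one of the MIRRORS,
# decompresses the XZ stream on the fly and replaces the local database only
# if the downloaded copy is newer than the current one.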