]> git.ipfire.org Git - location/libloc.git/commitdiff
Move location-downloader functionality into location-query
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 3 Jun 2020 17:06:13 +0000 (17:06 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 3 Jun 2020 17:06:13 +0000 (17:06 +0000)
The commands are very long and confusion. Hence we merge this
all into one command.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/python/downloader.py [moved from src/python/location-downloader.in with 60% similarity]
src/python/location-query.in
src/python/locationmodule.c

index 31869e0ca61448e7f7777a93e005399103be45af..c0b1300eb27ecd52c0ab1c541de17e9a18feae05 100644 (file)
@@ -146,6 +146,7 @@ CLEANFILES += \
 dist_pkgpython_PYTHON = \
        src/python/__init__.py \
        src/python/database.py \
+       src/python/downloader.py \
        src/python/i18n.py \
        src/python/importer.py \
        src/python/logger.py
@@ -234,19 +235,16 @@ uninstall-perl:
                $(DESTDIR)/$(prefix)/man/man3/Location.3pm
 
 bin_SCRIPTS = \
-       src/python/location-downloader \
        src/python/location-exporter \
        src/python/location-importer \
        src/python/location-query
 
 EXTRA_DIST += \
-       src/python/location-downloader.in \
        src/python/location-exporter.in \
        src/python/location-importer.in \
        src/python/location-query.in
 
 CLEANFILES += \
-       src/python/location-downloader \
        src/python/location-exporter \
        src/python/location-importer \
        src/python/location-query
similarity index 60%
rename from src/python/location-downloader.in
rename to src/python/downloader.py
index bf0d682467e469f0b6188742447c5b35898c3145..c9e6e0033d7be925def3e176de02492fa98da64c 100644 (file)
@@ -3,7 +3,7 @@
 #                                                                             #
 # libloc - A library to determine the location of someone on the Internet     #
 #                                                                             #
-# Copyright (C) 2019 IPFire Development Team <info@ipfire.org>                #
+# Copyright (C) 2020 IPFire Development Team <info@ipfire.org>                #
 #                                                                             #
 # This library is free software; you can redistribute it and/or               #
 # modify it under the terms of the GNU Lesser General Public                  #
 #                                                                             #
 ###############################################################################
 
-import argparse
-import datetime
 import logging
 import lzma
 import os
 import random
-import shutil
 import stat
-import sys
 import tempfile
 import time
 import urllib.error
 import urllib.parse
 import urllib.request
 
-# Load our location module
-import location
-from location.i18n import _
+from _location import Database, DATABASE_VERSION_LATEST
 
 DATABASE_FILENAME = "location.db.xz"
 MIRRORS = (
@@ -46,9 +40,11 @@ log = logging.getLogger("location.downloader")
 log.propagate = 1
 
 class Downloader(object):
-       def __init__(self, version, mirrors):
+       def __init__(self, version=DATABASE_VERSION_LATEST, mirrors=None):
                self.version = version
-               self.mirrors = list(mirrors)
+
+               # Set mirrors or use defaults
+               self.mirrors = list(mirrors or MIRRORS)
 
                # Randomize mirrors
                random.shuffle(self.mirrors)
@@ -117,9 +113,10 @@ class Downloader(object):
 
                return res
 
-       def download(self, url, public_key, timestamp=None, tmpdir=None, **kwargs):
-               headers = {}
+       def download(self, public_key, timestamp=None, tmpdir=None, **kwargs):
+               url = "%s/%s" % (self.version, DATABASE_FILENAME)
 
+               headers = {}
                if timestamp:
                        headers["If-Modified-Since"] = timestamp.strftime(
                                "%a, %d %b %Y %H:%M:%S GMT",
@@ -191,7 +188,7 @@ class Downloader(object):
                """
                log.debug("Opening downloaded database at %s" % f.name)
 
-               db = location.Database(f.name)
+               db = Database(f.name)
 
                # Database is not recent
                if timestamp and db.created_at < timestamp.timestamp():
@@ -208,141 +205,3 @@ class Downloader(object):
                                return False
 
                return True
-
-
-class CLI(object):
-       def __init__(self):
-               # Which version are we downloading?
-               self.version = location.DATABASE_VERSION_LATEST
-
-               self.downloader = Downloader(version=self.version, mirrors=MIRRORS)
-
-       def parse_cli(self):
-               parser = argparse.ArgumentParser(
-                       description=_("Location Downloader Command Line Interface"),
-               )
-               subparsers = parser.add_subparsers()
-
-               # Global configuration flags
-               parser.add_argument("--debug", action="store_true",
-                       help=_("Enable debug output"))
-               parser.add_argument("--quiet", action="store_true",
-                       help=_("Enable quiet mode"))
-
-               # version
-               parser.add_argument("--version", action="version",
-                       version="%(prog)s @VERSION@")
-
-               # database
-               parser.add_argument("--database", "-d",
-                       default="@databasedir@/database.db", help=_("Path to database"),
-               )
-
-               # public key
-               parser.add_argument("--public-key", "-k",
-                       default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
-               )
-
-               # Update
-               update = subparsers.add_parser("update", help=_("Update database"))
-               update.set_defaults(func=self.handle_update)
-
-               # Verify
-               verify = subparsers.add_parser("verify",
-                       help=_("Verify the downloaded database"))
-               verify.set_defaults(func=self.handle_verify)
-
-               args = parser.parse_args()
-
-               # Configure logging
-               if args.debug:
-                       location.logger.set_level(logging.DEBUG)
-               elif args.quiet:
-                       location.logger.set_level(logging.WARNING)
-
-               # Print usage if no action was given
-               if not "func" in args:
-                       parser.print_usage()
-                       sys.exit(2)
-
-               return args
-
-       def run(self):
-               # Parse command line arguments
-               args = self.parse_cli()
-
-               # Call function
-               ret = args.func(args)
-
-               # Return with exit code
-               if ret:
-                       sys.exit(ret)
-
-               # Otherwise just exit
-               sys.exit(0)
-
-       def handle_update(self, ns):
-               # Fetch the timestamp we need from DNS
-               t = location.discover_latest_version(self.version)
-
-               # Parse timestamp into datetime format
-               timestamp = datetime.datetime.fromtimestamp(t) if t else None
-
-               # Open database
-               try:
-                       db = location.Database(ns.database)
-
-                       # Check if we are already on the latest version
-                       if timestamp and db.created_at >= timestamp.timestamp():
-                               log.info("Already on the latest version")
-                               return
-
-               except FileNotFoundError as e:
-                       db = None
-
-               # Download the database into the correct directory
-               tmpdir = os.path.dirname(ns.database)
-
-               # Try downloading a new database
-               try:
-                       t = self.downloader.download("%s/%s" % (self.version, DATABASE_FILENAME),
-                               public_key=ns.public_key, timestamp=timestamp, tmpdir=tmpdir)
-
-               # If no file could be downloaded, log a message
-               except FileNotFoundError as e:
-                       log.error("Could not download a new database")
-                       return 1
-
-               # If we have not received a new file, there is nothing to do
-               if not t:
-                       return 3
-
-               # Move temporary file to destination
-               shutil.move(t.name, ns.database)
-
-               return 0
-
-       def handle_verify(self, ns):
-               try:
-                       db = location.Database(ns.database)
-               except FileNotFoundError as e:
-                       log.error("%s: %s" % (ns.database, e))
-                       return 127
-
-               # Verify the database
-               with open(ns.public_key, "r") as f:
-                       if not db.verify(f):
-                               log.error("Could not verify database")
-                               return 1
-
-               # Success
-               log.debug("Database successfully verified")
-               return 0
-
-
-def main():
-       # Run the command line interface
-       c = CLI()
-       c.run()
-
-main()
index dfdff8c2b804b08732d7734c310a1614cc35d3e9..02917861ec432f1f6a31e434162441dc06348624 100644 (file)
 ###############################################################################
 
 import argparse
+import datetime
 import ipaddress
+import logging
 import os
+import shutil
 import socket
 import sys
 import time
 
 # Load our location module
 import location
+import location.downloader
 from location.i18n import _
 
+# Setup logging
+log = logging.getLogger("location")
+
 # Output formatters
 
 class OutputFormatter(object):
@@ -157,6 +164,15 @@ class CLI(object):
                dump.add_argument("output", nargs="?", type=argparse.FileType("w"))
                dump.set_defaults(func=self.handle_dump)
 
+               # Update
+               update = subparsers.add_parser("update", help=_("Update database"))
+               update.set_defaults(func=self.handle_update)
+
+               # Verify
+               verify = subparsers.add_parser("verify",
+                       help=_("Verify the downloaded database"))
+               verify.set_defaults(func=self.handle_verify)
+
                # Get AS
                get_as = subparsers.add_parser("get-as",
                        help=_("Get information about one or multiple Autonomous Systems"),
@@ -423,6 +439,59 @@ class CLI(object):
                        for a in db.search_as(query):
                                print(a)
 
+       def handle_update(self, db, ns):
+               # Fetch the timestamp we need from DNS
+               t = location.discover_latest_version()
+
+               # Parse timestamp into datetime format
+               timestamp = datetime.datetime.fromtimestamp(t) if t else None
+
+               # Check the version of the local database
+               if db and timestamp and db.created_at >= timestamp.timestamp():
+                       log.info("Already on the latest version")
+                       return
+
+               # Download the database into the correct directory
+               tmpdir = os.path.dirname(ns.database)
+
+               # Create a downloader
+               d = location.downloader.Downloader()
+
+               # Try downloading a new database
+               try:
+                       t = d.download(public_key=ns.public_key, timestamp=timestamp, tmpdir=tmpdir)
+
+               # If no file could be downloaded, log a message
+               except FileNotFoundError as e:
+                       log.error("Could not download a new database")
+                       return 1
+
+               # If we have not received a new file, there is nothing to do
+               if not t:
+                       return 3
+
+               # Move temporary file to destination
+               shutil.move(t.name, ns.database)
+
+               return 0
+
+       def handle_verify(self, ns):
+               try:
+                       db = location.Database(ns.database)
+               except FileNotFoundError as e:
+                       log.error("%s: %s" % (ns.database, e))
+                       return 127
+
+               # Verify the database
+               with open(ns.public_key, "r") as f:
+                       if not db.verify(f):
+                               log.error("Could not verify database")
+                               return 1
+
+               # Success
+               log.debug("Database successfully verified")
+               return 0
+
        def __get_output_formatter(self, ns):
                try:
                        cls = self.output_formats[ns.output_format]
index a04cab7f0c7f5c415f02bfb1257cc34985b7b73d..5b72be9b0f9fbfe1df62dc00064d52691a51d1bd 100644 (file)
@@ -50,9 +50,9 @@ static PyObject* set_log_level(PyObject* m, PyObject* args) {
 }
 
 static PyObject* discover_latest_version(PyObject* m, PyObject* args) {
-       unsigned int version = 0;
+       unsigned int version = LOC_DATABASE_VERSION_LATEST;
 
-       if (!PyArg_ParseTuple(args, "i", &version))
+       if (!PyArg_ParseTuple(args, "|i", &version))
                return NULL;
 
        time_t t = 0;