]> git.ipfire.org Git - location/libloc.git/commitdiff
location-query: Allow passing the database path
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 2 Oct 2019 13:55:33 +0000 (13:55 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 2 Oct 2019 13:55:33 +0000 (13:55 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/location-query.in

index 1433f1aacd1a559bbcbc643da4cad6dc739b6fd8..7a7492cc024ebdbfe6df43737d03262ed8a3a17b 100644 (file)
@@ -33,10 +33,6 @@ def _(singular, plural=None, n=None):
        return gettext.dgettext("libloc", singular)
 
 class CLI(object):
-       def __init__(self):
-               # Open database
-               self.db = location.Database("@databasedir@/database.db")
-
        def parse_cli(self):
                parser = argparse.ArgumentParser(
                        description=_("Location Database Command Line Interface"),
@@ -51,6 +47,11 @@ class CLI(object):
                parser.add_argument("--version", action="version",
                        version="%%(prog)s %s" % location.__version__)
 
+               # database
+               parser.add_argument("--database", "-d",
+                       default="@databasedir@/database.db", help=_("Path to database"),
+               )
+
                # lookup an IP address
                lookup = subparsers.add_parser("lookup",
                        help=_("Lookup one or multiple IP addresses"),
@@ -81,8 +82,16 @@ class CLI(object):
                # Callback function must be defined
                assert args.func, "Callback function not defined"
 
+               # Open database
+               try:
+                       db = location.Database(args.database)
+               except FileNotFoundError as e:
+                       sys.stderr.write("location-query: Could not open database %s: %s\n" \
+                               % (args.database, e))
+                       sys.exit(1)
+
                # Call function
-               ret = args.func(args)
+               ret = args.func(db, args)
 
                # Return with exit code
                if ret:
@@ -91,12 +100,12 @@ class CLI(object):
                # Otherwise just exit
                sys.exit(0)
 
-       def handle_lookup(self, ns):
+       def handle_lookup(self, db, ns):
                ret = 0
 
                for address in ns.address:
                        try:
-                               n = self.db.lookup(address)
+                               n = db.lookup(address)
                        except ValueError:
                                print(_("Invalid IP address: %s") % address, file=sys.stderr)
 
@@ -113,7 +122,7 @@ class CLI(object):
 
                        # Try to retrieve the AS if we have an AS number
                        if n.asn:
-                               a = self.db.get_as(n.asn)
+                               a = db.get_as(n.asn)
 
                                # If we have found an AS we will print it in the message
                                if a:
@@ -128,7 +137,7 @@ class CLI(object):
 
                return ret
 
-       def handle_get_as(self, ns):
+       def handle_get_as(self, db, ns):
                """
                        Gets information about Autonomous Systems
                """
@@ -143,7 +152,7 @@ class CLI(object):
                                continue
 
                        # Fetch AS from database
-                       a = self.db.get_as(asn)
+                       a = db.get_as(asn)
 
                        # Nothing found
                        if not a:
@@ -155,10 +164,10 @@ class CLI(object):
 
                return ret
 
-       def handle_search_as(self, ns):
+       def handle_search_as(self, db, ns):
                for query in ns.query:
                        # Print all matches ASes
-                       for a in self.db.search_as(query):
+                       for a in db.search_as(query):
                                print(a)
 
 def main():